From e1d1fe114a6a7c8b74b8f7ba6e63ce9acb33e2a2 Mon Sep 17 00:00:00 2001 From: Gowtham Rao MD PhD Date: Thu, 29 Jan 2026 06:39:46 -0500 Subject: [PATCH] Containerize coreason-optimizer as an Optimization Microservice (#50) * feat: containerize coreason-optimizer as microservice - Add fastapi and uvicorn dependencies - Define AgentDefinition and OptimizationRequest schemas - Implement server.py with DynamicConstruct adapter and Bridged clients - Update Dockerfile to run uvicorn - Add server tests Co-authored-by: gowthamrao <13936600+gowthamrao@users.noreply.github.com> * feat: containerize coreason-optimizer as microservice - Add fastapi and uvicorn dependencies - Define AgentDefinition and OptimizationRequest schemas - Implement server.py with DynamicConstruct adapter and Bridged clients - Update Dockerfile to run uvicorn - Add server tests - Fix linting and typing issues in server code Co-authored-by: gowthamrao <13936600+gowthamrao@users.noreply.github.com> * fix(ci): address linting, typing, and coverage issues in server implementation - Fix `B904` linting errors in `server.py` by adding `from e`. - Fix Mypy `no-any-return` errors in `server.py` using `cast`. - Fix Mypy errors in `tests/test_server_basic.py` (return types, indexing). - Add `test_optimize_semantic_success` to improve coverage of the semantic selector path in `server.py`. - Apply code formatting via `ruff`. - Achieve 100% test coverage. 
Co-authored-by: gowthamrao <13936600+gowthamrao@users.noreply.github.com> * feat: containerize coreason-optimizer as microservice - Add fastapi and uvicorn dependencies - Define AgentDefinition and OptimizationRequest schemas - Implement server.py with DynamicConstruct adapter and Bridged clients - Update Dockerfile to run uvicorn - Add server tests - Fix linting and typing issues in server code (B904, Mypy) - Apply ruff formatting and sorting imports Co-authored-by: gowthamrao <13936600+gowthamrao@users.noreply.github.com> --- Dockerfile | 2 + README.md | 53 +++--- docs/requirements.md | 35 ++++ docs/usage.md | 130 +++++++++++++ poetry.lock | 75 +++++++- pyproject.toml | 4 +- requirements.txt | 2 + src/coreason_optimizer/server.py | 186 ++++++++++++++++++ src/coreason_optimizer/server_schemas.py | 43 +++++ tests/test_server_basic.py | 228 +++++++++++++++++++++++ 10 files changed, 732 insertions(+), 26 deletions(-) create mode 100644 docs/requirements.md create mode 100644 docs/usage.md create mode 100644 src/coreason_optimizer/server.py create mode 100644 src/coreason_optimizer/server_schemas.py create mode 100644 tests/test_server_basic.py diff --git a/Dockerfile b/Dockerfile index 09828c8..fadeffc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -35,3 +35,5 @@ COPY --from=builder /wheels /wheels # Install the application wheel RUN pip install --no-cache-dir /wheels/*.whl + +CMD ["uvicorn", "coreason_optimizer.server:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/README.md b/README.md index d601cf2..d10e901 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ pip install coreason-optimizer ## Features - **Automated Optimization:** Rewrites instructions and selects examples to maximize a score, not human intuition. +- **Optimization-as-a-Service:** Run as a microservice API to compile prompts on-demand. - **Model-Specific Compilation:** Generates optimized prompts specifically tuned for target models (e.g., GPT-4, Claude 3.5). 
- **Continuous Learning:** Re-runs optimization on recent logs to patch prompts against data drift. - **Mutate-Evaluate Loop:** Systematic cycle of drafting, evaluating, diagnosing, mutating, and selecting prompts. @@ -30,40 +31,44 @@ For full product requirements, see [docs/product_requirements.md](docs/product_r ## Usage -Here is how to initialize and use the library to compile an agent: +You can use `coreason-optimizer` as a Python library, a CLI tool, or a Microservice. + +### 1. Python Library ```python from coreason_optimizer import OptimizerConfig, PromptOptimizer from coreason_optimizer.core.interfaces import Construct from coreason_optimizer.data import Dataset -# 1. Configuration -config = OptimizerConfig( - target_model="gpt-4o", - metric="exact_match", - max_rounds=10 -) - -# 2. Load Data -dataset = Dataset.from_csv("data/gold_set.csv") -train_set, val_set = dataset.split(test_size=0.2) - -# 3. Load Agent (Construct) -# In a real scenario, this would be imported from your agent code -# from src.agents.analyst import analyst_agent +# Define Agent class MockAgent(Construct): inputs = ["question"] outputs = ["answer"] system_prompt = "You are a helpful assistant." agent = MockAgent() -# 4. Compile -optimizer = PromptOptimizer(config=config) -optimized_manifest = optimizer.compile( - agent=agent, - trainset=train_set, - valset=val_set -) +# Compile +dataset = Dataset.from_csv("data/gold_set.csv") +train_set, val_set = dataset.split(train_ratio=0.8) + +optimizer = PromptOptimizer(config=OptimizerConfig(target_model="gpt-4o")) +manifest = optimizer.compile(agent, train_set, val_set) + +print(f"Optimized Score: {manifest.performance_metric}") +``` + +### 2. 
Server Mode (Microservice) + +Run the optimizer as a standalone service using Docker: + +```bash +docker run -p 8000:8000 -e OPENAI_API_KEY=$OPENAI_API_KEY coreason-optimizer:latest +``` + +Then call the API: + +```bash +curl -X POST http://localhost:8000/optimize -H "Content-Type: application/json" -d @request.json +``` -print(f"Optimization complete. New Score: {optimized_manifest.performance_metric}") -print(f"Optimized Instruction: {optimized_manifest.optimized_instruction}") +For detailed instructions, see **[docs/usage.md](docs/usage.md)**. diff --git a/docs/requirements.md b/docs/requirements.md new file mode 100644 index 0000000..edc0cdb --- /dev/null +++ b/docs/requirements.md @@ -0,0 +1,35 @@ +# Requirements + +## Runtime Dependencies + +The following packages are required for `coreason-optimizer` to function: + +* `python >= 3.12` +* `click >= 8.0.0` +* `jinja2 >= 3.0.0` +* `loguru >= 0.6.0` +* `numpy >= 1.20.0` +* `openai >= 1.0.0` +* `pydantic >= 2.0.0` +* `scikit-learn >= 1.0.0` +* `coreason-identity >= 0.4.1` + +### Server Mode (Microservice) + +For running the optimization microservice (Server Mode), the following additional dependencies are required: + +* `fastapi >= 0.100.0` +* `uvicorn >= 0.20.0` +* `httpx` +* `anyio` + +## Development Dependencies + +For development and testing: + +* `pytest` +* `ruff` +* `pre-commit` +* `pytest-cov` +* `mkdocs` +* `mkdocs-material` diff --git a/docs/usage.md b/docs/usage.md new file mode 100644 index 0000000..a55e5e5 --- /dev/null +++ b/docs/usage.md @@ -0,0 +1,130 @@ +# Usage Guide + +`coreason-optimizer` offers three ways to compile and optimize agents: +1. **Python Library:** Integrate directly into your Python code. +2. **CLI:** Use the command-line tool for scripts and CI/CD. +3. **Server Mode (Microservice):** Run as a standalone API service. + +--- + +## 1. Python Library + +Use the library to programmatically compile agents. 
+ +```python +from coreason_optimizer import OptimizerConfig, PromptOptimizer +from coreason_optimizer.core.interfaces import Construct +from coreason_optimizer.data import Dataset + +# 1. Define your Agent +class MyAgent(Construct): + inputs = ["user_query"] + outputs = ["response", "thought_trace"] + system_prompt = "You are a helpful AI assistant." + +agent = MyAgent() + +# 2. Load Data +dataset = Dataset.from_csv("data/training_data.csv") +train_set, val_set = dataset.split(train_ratio=0.8) + +# 3. Configure Optimizer +config = OptimizerConfig( + target_model="gpt-4o", + metric="exact_match", + budget_limit_usd=5.0 +) + +# 4. Compile +optimizer = PromptOptimizer(config=config) +manifest = optimizer.compile(agent, train_set, val_set) + +print(f"Optimization Score: {manifest.performance_metric}") +``` + +--- + +## 2. Command Line Interface (CLI) + +The `coreason-opt` CLI allows you to run optimization jobs from the shell. + +### Tune an Agent + +```bash +coreason-opt tune \ + --agent src/agents/analyst.py \ + --dataset data/gold_set.csv \ + --strategy mipro \ + --output optimized_analyst.json +``` + +### Evaluate a Manifest + +```bash +coreason-opt evaluate \ + --manifest optimized_analyst.json \ + --dataset data/test_set.csv +``` + +--- + +## 3. Server Mode (Optimization-as-a-Service) + +You can run `coreason-optimizer` as a containerized microservice that accepts optimization requests via HTTP. + +### Starting the Server + +**Using Docker:** + +```bash +docker build -t coreason-optimizer:latest . +docker run -p 8000:8000 -e OPENAI_API_KEY=sk-... coreason-optimizer:latest +``` + +**Using Uvicorn (Locally):** + +```bash +uvicorn coreason_optimizer.server:app --host 0.0.0.0 --port 8000 +``` + +### API Endpoints + +#### `POST /optimize` + +Runs an optimization job synchronously; the request blocks until compilation finishes and the result is returned in the response. 
+ +**Request Body (JSON):** + +```json +{ + "agent": { + "system_prompt": "You are a specialized medical analyst...", + "inputs": ["patient_notes"], + "outputs": ["diagnosis_code"] + }, + "dataset": [ + { + "inputs": {"patient_notes": "Patient complains of..."}, + "reference": "E11.9", + "metadata": {"source": "manual_review"} + }, + ... + ], + "config": { + "target_model": "gpt-4o", + "metric": "exact_match", + "budget_limit_usd": 10.0 + }, + "strategy": "mipro" +} +``` + +**Response:** + +Returns an `OptimizedManifest` JSON object containing the optimized instruction and selected few-shot examples. + +#### `GET /health` + +Checks service health. + +**Response:** `{"status": "ready"}` diff --git a/poetry.lock b/poetry.lock index 6dee24b..50412ff 100644 --- a/poetry.lock +++ b/poetry.lock @@ -12,6 +12,18 @@ files = [ {file = "aiofiles-25.1.0.tar.gz", hash = "sha256:a8d728f0a29de45dc521f18f07297428d56992a742f0cd2701ba86e44d23d5b2"}, ] +[[package]] +name = "annotated-doc" +version = "0.0.4" +description = "Document parameters, class attributes, return types, and variables inline, with Annotated." 
+optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320"}, + {file = "annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4"}, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -628,6 +640,29 @@ files = [ dnspython = ">=2.0.0" idna = ">=2.0.0" +[[package]] +name = "fastapi" +version = "0.128.0" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d"}, + {file = "fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a"}, +] + +[package.dependencies] +annotated-doc = ">=0.0.2" +pydantic = ">=2.7.0" +starlette = ">=0.40.0,<0.51.0" +typing-extensions = ">=4.8.0" + +[package.extras] +all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"] +standard-no-fastapi-cloud-cli = ["email-validator (>=2.0.0)", "fastapi-cli[standard-no-fastapi-cloud-cli] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", 
"uvicorn[standard] (>=0.12.0)"] + [[package]] name = "filelock" version = "3.20.3" @@ -2030,6 +2065,25 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "starlette" +version = "0.50.0" +description = "The little ASGI library that shines." +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "starlette-0.50.0-py3-none-any.whl", hash = "sha256:9e5391843ec9b6e472eed1365a78c8098cfceb7a74bfd4d6b1c0c0095efb3bca"}, + {file = "starlette-0.50.0.tar.gz", hash = "sha256:a2a17b22203254bcbc2e1f926d2d55f3f9497f769416b3190768befe598fa3ca"}, +] + +[package.dependencies] +anyio = ">=3.6.2,<5" +typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""} + +[package.extras] +full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"] + [[package]] name = "threadpoolctl" version = "3.6.0" @@ -2122,6 +2176,25 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["backports-zstd (>=1.0.0) ; python_version < \"3.14\""] +[[package]] +name = "uvicorn" +version = "0.40.0" +description = "The lightning-fast ASGI server." 
+optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "uvicorn-0.40.0-py3-none-any.whl", hash = "sha256:c6c8f55bc8bf13eb6fa9ff87ad62308bbbc33d0b67f84293151efe87e0d5f2ee"}, + {file = "uvicorn-0.40.0.tar.gz", hash = "sha256:839676675e87e73694518b5574fd0f24c9d97b46bea16df7b8c05ea1a51071ea"}, +] + +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" + +[package.extras] +standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.15.1) ; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"", "watchfiles (>=0.13)", "websockets (>=10.4)"] + [[package]] name = "virtualenv" version = "20.36.1" @@ -2225,4 +2298,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.12, <3.15" -content-hash = "9bf8e3203edcb35feda66fd762450d58787a725746b6358aeebd156390994071" +content-hash = "32a676fff1e40493b63349f4aba54761d1b02e958a49ed7350f7909efdff1bf6" diff --git a/pyproject.toml b/pyproject.toml index 88cf4cd..2650e6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "coreason_optimizer" -version = "0.2.1" +version = "0.3.0" description = "coreason-optimizer" authors = ["Gowtham A Rao "] license = "Prosperity-3.0" @@ -21,6 +21,8 @@ httpx = "^0.28.1" aiofiles = "*" types-aiofiles = "*" coreason-identity = "^0.4.1" +fastapi = "^0.128.0" +uvicorn = "^0.40.0" [tool.poetry.scripts] coreason-opt = "coreason_optimizer.main:cli" diff --git a/requirements.txt b/requirements.txt index 669f6c4..8d1c401 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,5 @@ numpy>=1.20.0 openai>=1.0.0 pydantic>=2.0.0 scikit-learn>=1.0.0 +fastapi>=0.100.0 +uvicorn>=0.20.0 diff --git a/src/coreason_optimizer/server.py b/src/coreason_optimizer/server.py new file mode 100644 index 0000000..0e614d7 --- /dev/null +++ b/src/coreason_optimizer/server.py @@ -0,0 +1,186 @@ +# Copyright 
(c) 2025 CoReason, Inc. +# +# This software is proprietary and dual-licensed. +# Licensed under the Prosperity Public License 3.0 (the "License"). +# A copy of the license is available at https://prosperitylicense.com/versions/3.0.0 +# For details, see the LICENSE file. +# Commercial use beyond a 30-day trial requires a separate license. +# +# Source Code: https://github.com/CoReason-AI/coreason_optimizer + +""" +FastAPI Server implementation for the Coreason Optimization Microservice. +""" + +import contextlib +from typing import Any, AsyncIterator, cast + +import anyio +import httpx +from fastapi import FastAPI, HTTPException, Request +from openai import AsyncOpenAI + +from coreason_optimizer.core.client import OpenAIClientAsync, OpenAIEmbeddingClientAsync +from coreason_optimizer.core.interfaces import ( + AsyncEmbeddingProvider, + AsyncLLMClient, + EmbeddingResponse, + LLMResponse, +) +from coreason_optimizer.core.metrics import MetricFactory +from coreason_optimizer.data.loader import Dataset +from coreason_optimizer.server_schemas import AgentDefinition, OptimizationRequest +from coreason_optimizer.strategies.bootstrap import BootstrapFewShot +from coreason_optimizer.strategies.mipro import MiproOptimizer + +# --- Adapters --- + + +class DynamicConstruct: + """ + Adapter to make AgentDefinition satisfy the Construct protocol. + """ + + def __init__(self, agent_def: AgentDefinition): + self._agent = agent_def + + @property + def inputs(self) -> list[str]: + return self._agent.inputs + + @property + def outputs(self) -> list[str]: + return self._agent.outputs + + @property + def system_prompt(self) -> str: + return self._agent.system_prompt + + +class BridgedLLMClient: + """ + Sync wrapper that bridges calls to an AsyncLLMClient running in the main loop. + + This allows synchronous strategies to use async clients managed by FastAPI. 
+ """ + + def __init__(self, async_client: AsyncLLMClient): + self.async_client = async_client + + def generate( + self, + messages: list[dict[str, str]], + model: str | None = None, + temperature: float = 0.0, + **kwargs: Any, + ) -> LLMResponse: + async def _call() -> LLMResponse: + return await self.async_client.generate(messages=messages, model=model, temperature=temperature, **kwargs) + + # Dispatch to the main event loop + return cast(LLMResponse, anyio.from_thread.run(_call)) + + +class BridgedEmbeddingProvider: + """ + Sync wrapper that bridges calls to an AsyncEmbeddingProvider running in the main loop. + """ + + def __init__(self, async_provider: AsyncEmbeddingProvider): + self.async_provider = async_provider + + def embed(self, texts: list[str], model: str | None = None) -> EmbeddingResponse: + async def _call() -> EmbeddingResponse: + return await self.async_provider.embed(texts=texts, model=model) + + # Dispatch to the main event loop + return cast(EmbeddingResponse, anyio.from_thread.run(_call)) + + +# --- Lifespan --- + + +@contextlib.asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncIterator[None]: + """ + Manage the lifecycle of shared clients. 
+ """ + # Initialize shared clients + # We use a single httpx client for connection pooling + http_client = httpx.AsyncClient() + + # Initialize AsyncOpenAI with the shared http client + # This expects OPENAI_API_KEY to be present in environment variables + openai_client = AsyncOpenAI(http_client=http_client) + + app.state.http_client = http_client + app.state.openai_client = openai_client + + # Create Coreason Async wrappers using the shared clients + # We pass the shared clients so they are reused + app.state.llm_client_async = OpenAIClientAsync(client=openai_client, http_client=http_client) + app.state.embedding_client_async = OpenAIEmbeddingClientAsync(client=openai_client, http_client=http_client) + + yield + + # Cleanup + await openai_client.close() + await http_client.aclose() + + +# --- Server --- + +app = FastAPI(lifespan=lifespan) + + +@app.get("/health") +def health() -> dict[str, str]: + return {"status": "ready"} + + +@app.post("/optimize") +def optimize(request: Request, body: OptimizationRequest) -> Any: + # 1. Adapter: Convert AgentDefinition to Construct + agent = DynamicConstruct(body.agent) + + # 2. Dataset: Load and split + # We use the Dataset class to handle splitting logic + dataset = Dataset(body.dataset) + # Default 80/20 split as requested + train_set, val_set, _ = dataset.split(train_ratio=0.8, val_ratio=0.2) + train_list = list(train_set) + val_list = list(val_set) + + # 3. Clients: Bridge to shared async clients + if not hasattr(request.app.state, "llm_client_async"): + raise HTTPException(status_code=500, detail="LLM Client not initialized") + + llm_client = BridgedLLMClient(request.app.state.llm_client_async) + + embedding_provider = None + if body.config.selector_type == "semantic": + if not hasattr(request.app.state, "embedding_client_async"): + raise HTTPException(status_code=500, detail="Embedding Client not initialized") + embedding_provider = BridgedEmbeddingProvider(request.app.state.embedding_client_async) + + # 4. 
Metric + try: + metric = MetricFactory.get(body.config.metric) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + + # 5. Strategy: Initialize the optimizer + optimizer: Any + if body.strategy == "bootstrap": + optimizer = BootstrapFewShot(llm_client, metric, body.config) + else: + # Default to Mipro + optimizer = MiproOptimizer(llm_client, metric, body.config, embedding_provider=embedding_provider) + + # 6. Run Compilation + try: + manifest = optimizer.compile(agent, train_list, val_list) + return manifest + except Exception as e: + # In production, we should log the full traceback + raise HTTPException(status_code=500, detail=f"Optimization failed: {str(e)}") from e diff --git a/src/coreason_optimizer/server_schemas.py b/src/coreason_optimizer/server_schemas.py new file mode 100644 index 0000000..fff7e00 --- /dev/null +++ b/src/coreason_optimizer/server_schemas.py @@ -0,0 +1,43 @@ +# Copyright (c) 2025 CoReason, Inc. +# +# This software is proprietary and dual-licensed. +# Licensed under the Prosperity Public License 3.0 (the "License"). +# A copy of the license is available at https://prosperitylicense.com/versions/3.0.0 +# For details, see the LICENSE file. +# Commercial use beyond a 30-day trial requires a separate license. +# +# Source Code: https://github.com/CoReason-AI/coreason_optimizer + +""" +API Schemas for the Optimization Microservice. +""" + +from typing import Literal + +from pydantic import BaseModel, Field + +from coreason_optimizer.core.config import OptimizerConfig +from coreason_optimizer.core.models import TrainingExample + + +class AgentDefinition(BaseModel): + """ + Schema for an agent definition in the Optimization request. + + Mirrors the Construct protocol properties. 
+ """ + + system_prompt: str = Field(..., description="The initial system prompt text.") + inputs: list[str] = Field(..., description="List of input field names.") + outputs: list[str] = Field(..., description="List of output field names.") + + +class OptimizationRequest(BaseModel): + """ + Request schema for the optimization endpoint. + """ + + agent: AgentDefinition = Field(..., description="The agent to optimize.") + dataset: list[TrainingExample] = Field(..., description="List of training examples.") + config: OptimizerConfig = Field(default_factory=OptimizerConfig, description="Optimization configuration.") + strategy: Literal["mipro", "bootstrap"] = Field(default="mipro", description="Optimization strategy to use.") diff --git a/tests/test_server_basic.py b/tests/test_server_basic.py new file mode 100644 index 0000000..2e60190 --- /dev/null +++ b/tests/test_server_basic.py @@ -0,0 +1,228 @@ +import os +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +# Set dummy API key for tests +os.environ["OPENAI_API_KEY"] = "sk-dummy-key" + +from coreason_optimizer.core.models import OptimizedManifest +from coreason_optimizer.server import app + + +def test_health() -> None: + with TestClient(app) as client: + response = client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ready"} + + +def test_optimize_schema_validation() -> None: + with TestClient(app) as client: + # Invalid request (missing fields) + response = client.post("/optimize", json={}) + assert response.status_code == 422 + + +@patch("coreason_optimizer.server.MiproOptimizer") +@patch("coreason_optimizer.server.BootstrapFewShot") +@patch("coreason_optimizer.server.MetricFactory") +def test_optimize_endpoint(mock_metric: MagicMock, mock_bootstrap: MagicMock, mock_mipro: MagicMock) -> None: + # Setup mocks + mock_optimizer_instance = MagicMock() + mock_mipro.return_value = 
mock_optimizer_instance + + mock_manifest = OptimizedManifest( + agent_id="test_agent", + base_model="gpt-4o", + optimized_instruction="Optimized system prompt", + few_shot_examples=[], + performance_metric=1.0, + optimization_run_id="test_run", + ) + mock_optimizer_instance.compile.return_value = mock_manifest + + payload = { + "agent": {"system_prompt": "Original prompt", "inputs": ["input1"], "outputs": ["output1"]}, + "dataset": [ + {"inputs": {"input1": "val1"}, "reference": "ref1", "metadata": {}}, + {"inputs": {"input1": "val2"}, "reference": "ref2", "metadata": {}}, + ], + "config": {"metric": "exact_match", "target_model": "gpt-4o"}, + "strategy": "mipro", + } + + with TestClient(app) as client: + response = client.post("/optimize", json=payload) + + # If the response is not 200, print details + if response.status_code != 200: + print(response.json()) + + assert response.status_code == 200 + data = response.json() + assert data["agent_id"] == "test_agent" + assert data["optimized_instruction"] == "Optimized system prompt" + + # Verify Mipro was called + mock_mipro.assert_called_once() + mock_optimizer_instance.compile.assert_called_once() + + +def test_dynamic_construct() -> None: + from coreason_optimizer.server import DynamicConstruct + from coreason_optimizer.server_schemas import AgentDefinition + + ad = AgentDefinition(system_prompt="sys", inputs=["i"], outputs=["o"]) + dc = DynamicConstruct(ad) + assert dc.system_prompt == "sys" + assert dc.inputs == ["i"] + assert dc.outputs == ["o"] + + +@pytest.mark.asyncio +async def test_bridged_client() -> None: + import anyio + + from coreason_optimizer.server import BridgedLLMClient + + mock_async = AsyncMock() + mock_async.generate.return_value = "response" + + bridge = BridgedLLMClient(mock_async) + + def worker() -> str: + return str(bridge.generate(messages=[])) + + result = await anyio.to_thread.run_sync(worker) + assert result == "response" + mock_async.generate.assert_called_once() + + 
+@pytest.mark.asyncio +async def test_bridged_embedding_provider() -> None: + import anyio + + from coreason_optimizer.server import BridgedEmbeddingProvider + + mock_async = AsyncMock() + mock_async.embed.return_value = "embeddings" + + bridge = BridgedEmbeddingProvider(mock_async) + + def worker() -> str: + return str(bridge.embed(texts=[])) + + result = await anyio.to_thread.run_sync(worker) + assert result == "embeddings" + mock_async.embed.assert_called_once() + + +def test_optimize_errors_and_bootstrap() -> None: + payload: dict[str, Any] = { + "agent": {"system_prompt": "Original prompt", "inputs": ["input1"], "outputs": ["output1"]}, + "dataset": [{"inputs": {"input1": "val1"}, "reference": "ref1", "metadata": {}}], + "config": {"metric": "unknown_metric", "target_model": "gpt-4o"}, + "strategy": "mipro", + } + + # 1. Unknown metric + with TestClient(app) as client: + response = client.post("/optimize", json=payload) + assert response.status_code == 400 + assert "Unknown metric" in response.text + + # 2. Bootstrap strategy + payload["config"]["metric"] = "exact_match" + payload["strategy"] = "bootstrap" + + with patch("coreason_optimizer.server.BootstrapFewShot") as mock_boot: + mock_instance = MagicMock() + mock_boot.return_value = mock_instance + # Mock compile return + mock_instance.compile.return_value = OptimizedManifest( + agent_id="test", + base_model="gpt", + optimized_instruction="sys", + performance_metric=1.0, + optimization_run_id="id", + ) + + with TestClient(app) as client: + response = client.post("/optimize", json=payload) + assert response.status_code == 200 + mock_boot.assert_called_once() + + # 3. 
Exception handling + payload["strategy"] = "mipro" + with patch("coreason_optimizer.server.MiproOptimizer") as mock_mipro: + mock_mipro.return_value.compile.side_effect = Exception("Boom") + with TestClient(app) as client: + response = client.post("/optimize", json=payload) + assert response.status_code == 500 + assert "Boom" in response.text + + +def test_missing_state() -> None: + payload: dict[str, Any] = { + "agent": {"system_prompt": "Original prompt", "inputs": ["input1"], "outputs": ["output1"]}, + "dataset": [{"inputs": {"input1": "val1"}, "reference": "ref1", "metadata": {}}], + "config": {"metric": "exact_match", "selector_type": "random"}, + "strategy": "mipro", + } + + with TestClient(app) as client: + # 1. Missing LLM Client + # We need to temporarily remove the attr from app.state + # app.state is available via client.app.state + llm = client.app.state.llm_client_async + del client.app.state.llm_client_async + + response = client.post("/optimize", json=payload) + assert response.status_code == 500 + assert "LLM Client not initialized" in response.text + + # Restore + client.app.state.llm_client_async = llm + + # 2. 
Missing Embedding Client (when selector is semantic) + payload["config"]["selector_type"] = "semantic" + embed = client.app.state.embedding_client_async + del client.app.state.embedding_client_async + + response = client.post("/optimize", json=payload) + assert response.status_code == 500 + assert "Embedding Client not initialized" in response.text + + # Restore + client.app.state.embedding_client_async = embed + + +@patch("coreason_optimizer.server.MiproOptimizer") +@patch("coreason_optimizer.server.MetricFactory") +def test_optimize_semantic_success(mock_metric: MagicMock, mock_mipro: MagicMock) -> None: + mock_optimizer_instance = MagicMock() + mock_mipro.return_value = mock_optimizer_instance + mock_optimizer_instance.compile.return_value = OptimizedManifest( + agent_id="test", base_model="gpt", optimized_instruction="sys", performance_metric=1.0, optimization_run_id="id" + ) + + payload = { + "agent": {"system_prompt": "Original prompt", "inputs": ["input1"], "outputs": ["output1"]}, + "dataset": [{"inputs": {"input1": "val1"}, "reference": "ref1", "metadata": {}}], + "config": {"metric": "exact_match", "selector_type": "semantic"}, + "strategy": "mipro", + } + + with TestClient(app) as client: + # Ensure embedding client exists in state (it should by default in lifespan) + assert hasattr(client.app.state, "embedding_client_async") + + response = client.post("/optimize", json=payload) + assert response.status_code == 200 + + # Verify Mipro was initialized with an embedding provider + args, kwargs = mock_mipro.call_args + assert kwargs.get("embedding_provider") is not None