|
3 | 3 | import json |
4 | 4 | import os |
5 | 5 | import typer |
| 6 | +import openai |
6 | 7 | from rich import print |
7 | 8 | from rich.table import Table |
8 | 9 | from rich.console import Console |
9 | 10 |
|
| 11 | +from autoprompt.adapters.base import AgentAdapter |
| 12 | +from autoprompt.adapters.callable import CallableAdapter |
| 13 | +from autoprompt.adapters.cli import CLIAdapter |
| 14 | +from autoprompt.adapters.http import HttpAdapter |
| 15 | +from autoprompt.core.budget import BudgetTracker |
10 | 16 | from autoprompt.core.config import load_config |
11 | 17 | from autoprompt.core.runner import Runner |
| 18 | +from autoprompt.pipeline.calibrator import Calibrator |
| 19 | +from autoprompt.pipeline.labeler import generate_labels |
12 | 20 |
|
13 | 21 | import dotenv |
14 | 22 | dotenv.load_dotenv() |
|
17 | 25 | console = Console() |
18 | 26 |
|
19 | 27 |
|
def _build_adapter(cfg) -> AgentAdapter:
    """Instantiate the agent adapter selected by ``cfg.agent.adapter``.

    Supported kinds are ``"http"``, ``"python_callable"`` and ``"cli"``;
    any other value raises ``ValueError``.
    """
    agent = cfg.agent
    # Dispatch table of lazy factories: only the selected adapter's config
    # attribute is ever read, matching the original if-chain's behavior.
    factories = {
        "http": lambda: HttpAdapter(agent.endpoint),
        "python_callable": lambda: CallableAdapter(agent.import_path),
        "cli": lambda: CLIAdapter(agent.command),
    }
    try:
        make = factories[agent.adapter]
    except KeyError:
        raise ValueError(f"Unknown adapter: {agent.adapter}") from None
    return make()
| 36 | + |
| 37 | + |
def _resolve_rubric_path(cfg, config_path: str) -> None:
    """Make ``cfg.rubric.path`` absolute, mutating *cfg* in place.

    A relative rubric path is interpreted relative to the directory that
    contains *config_path* (or the current working directory when
    *config_path* is empty). An already-absolute path is left untouched.
    """
    if os.path.isabs(cfg.rubric.path):
        return
    if config_path:
        base_dir = os.path.dirname(os.path.abspath(config_path))
    else:
        base_dir = os.getcwd()
    cfg.rubric.path = os.path.join(base_dir, cfg.rubric.path)
| 42 | + |
| 43 | + |
def _build_openrouter_client() -> openai.AsyncOpenAI:
    """Create an async OpenAI-compatible client pointed at OpenRouter.

    The key is read from the ``OPENROUTER_API_KEY`` environment variable.
    A missing key only produces a warning here — construction still
    succeeds and subsequent API calls fail instead.
    """
    key = os.getenv("OPENROUTER_API_KEY")
    if not key:
        print("[yellow]Warning: OPENROUTER_API_KEY not set. API calls will fail.[/yellow]")
    client = openai.AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=key)
    return client
| 52 | + |
| 53 | + |
20 | 54 | @app.command() |
21 | 55 | def run(config: str = typer.Argument("autoprompt.yaml", help="Path to config file"), dry_run: bool = False): |
22 | 56 | """Start the AutoPrompt optimization loop.""" |
@@ -169,6 +203,67 @@ def snapshot_lines(snapshot: dict) -> list[str]: |
169 | 203 | print(f"[bold red]Error:[/bold red] {e}") |
170 | 204 |
|
171 | 205 |
|
@app.command()
def label(
    config: str = typer.Argument("autoprompt.yaml", help="Path to config file"),
    count: int = typer.Option(10, "--count", "-n", help="Number of prompts to generate"),
    out: str = typer.Option("labels.yaml", "--out", "-o", help="Output labels file"),
):
    """Generate prompts, call the real agent, and collect manual scores interactively."""
    try:
        cfg = load_config(config)
        # Relative paths inside the config are resolved against its directory.
        config_dir = os.path.dirname(os.path.abspath(config)) if config else os.getcwd()
        adapter = _build_adapter(cfg)
        client = _build_openrouter_client()
        budget = BudgetTracker(cfg.loop.budget_limit_usd)

        async def _label_flow():
            # Refuse to start the labeling session if the agent is unreachable.
            if not await adapter.health_check():
                print("[bold red]Agent is not responding.[/bold red]")
                return
            await generate_labels(
                config=cfg,
                adapter=adapter,
                client=client,
                budget=budget,
                output_path=out,
                count=count,
                config_dir=config_dir,
            )

        asyncio.run(_label_flow())
    except Exception as e:
        # CLI boundary: surface the error (with traceback) instead of crashing typer.
        import traceback

        print(f"[bold red]Error:[/bold red] {e}")
        traceback.print_exc()
| 241 | + |
@app.command()
def calibrate(
    config: str = typer.Argument("autoprompt.yaml", help="Path to config file"),
    labels: str = typer.Option("labels.yaml", "--labels", "-l", help="Path to labels YAML"),
):
    """Compare LLM-judge scores against manual labels and print calibration report."""
    try:
        cfg = load_config(config)
        # The rubric path in the config may be relative to the config file.
        _resolve_rubric_path(cfg, config)
        budget = BudgetTracker(cfg.loop.budget_limit_usd)
        judge = Calibrator(_build_openrouter_client(), budget)

        async def _calibration_flow():
            # Score the labeled transcripts with the LLM judge, then summarize.
            outcome = await judge.run(labels, cfg.rubric)
            judge.report(outcome, cfg.rubric)
            print(f"[dim]Cost so far: ${budget.current_cost_usd:.4f}[/dim]")

        asyncio.run(_calibration_flow())
    except Exception as e:
        # CLI boundary: surface the error (with traceback) instead of crashing typer.
        import traceback

        print(f"[bold red]Error:[/bold red] {e}")
        traceback.print_exc()
| 266 | + |
172 | 267 |
|
# Script entry point: dispatch to the typer application.
if __name__ == "__main__":
    app()
0 commit comments