diff --git a/.gitignore b/.gitignore index 42d2b831..d4f0d2df 100644 --- a/.gitignore +++ b/.gitignore @@ -105,6 +105,9 @@ env.bak/ venv.bak/ *.backup +# Secrets +secrets.yaml + # Spyder project settings .spyderproject .spyproject diff --git a/eval_protocol/__init__.py b/eval_protocol/__init__.py index 6b96426f..2906337d 100644 --- a/eval_protocol/__init__.py +++ b/eval_protocol/__init__.py @@ -70,6 +70,14 @@ except ImportError: WeaveAdapter = None +try: + from .proxy import create_app, AuthProvider, AccountInfo +except ImportError: + create_app = None + AuthProvider = None + AccountInfo = None + + warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol") __all__ = [ @@ -130,6 +138,10 @@ "RolloutMetadata", "StatusResponse", "create_langfuse_config_tags", + # Proxy + "create_app", + "AuthProvider", + "AccountInfo", ] from . import _version diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index af0bf30c..707f983a 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -7,9 +7,9 @@ from __future__ import annotations import logging import requests -import time from datetime import datetime from typing import Any, Dict, List, Optional, Protocol +import os from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message from .base import BaseAdapter @@ -343,15 +343,17 @@ def get_evaluation_rows( # Remove None values params = {k: v for k, v in params.items() if v is not None} - # Make request to proxy + # Make request to proxy (using pointwise for efficiency) if self.project_id: - url = f"{self.base_url}/v1/project_id/{self.project_id}/traces" + url = f"{self.base_url}/v1/project_id/{self.project_id}/traces/pointwise" else: - url = f"{self.base_url}/v1/traces" + url = f"{self.base_url}/v1/traces/pointwise" + + headers = {"Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}"} result = None try: - 
response = requests.get(url, params=params, timeout=self.timeout) + response = requests.get(url, params=params, timeout=self.timeout, headers=headers) response.raise_for_status() result = response.json() except requests.exceptions.HTTPError as e: @@ -365,7 +367,7 @@ def get_evaluation_rows( except Exception: # In case e.response.json() fails error_msg = f"Proxy error: {e.response.text}" - logger.error("Failed to fetch traces from proxy: %s", error_msg) + logger.error("Failed to fetch traces from proxy (HTTP %s): %s", e.response.status_code, error_msg) return eval_rows except requests.exceptions.RequestException as e: # Non-HTTP errors (network issues, timeouts, etc.) diff --git a/eval_protocol/proxy/.env.example b/eval_protocol/proxy/.env.example new file mode 100644 index 00000000..1a6eb490 --- /dev/null +++ b/eval_protocol/proxy/.env.example @@ -0,0 +1,2 @@ +# In order to set other model providers keys for proxy, make a copy, rename to .env, and fill here +OPENAI_API_KEY=sk-proj-xxx diff --git a/eval_protocol/proxy/Dockerfile.gateway b/eval_protocol/proxy/Dockerfile.gateway new file mode 100644 index 00000000..a9308faa --- /dev/null +++ b/eval_protocol/proxy/Dockerfile.gateway @@ -0,0 +1,23 @@ +# Metadata Extraction Gateway - Sits in front of LiteLLM +FROM python:3.11-slim + +WORKDIR /app + +# Prevent Python from buffering stdout/stderr +ENV PYTHONUNBUFFERED=1 + +# Copy requirements file +COPY ./requirements.txt /app/requirements.txt + +# Install dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy the proxy package +COPY ./proxy_core /app/proxy_core + +# Expose port +EXPOSE 4000 + +# Run the gateway as a module +# LITELLM_URL will be set by environment (docker-compose or Cloud Run) +CMD ["python", "-m", "proxy_core.main"] diff --git a/eval_protocol/proxy/README.md b/eval_protocol/proxy/README.md new file mode 100644 index 00000000..ffcdaf25 --- /dev/null +++ b/eval_protocol/proxy/README.md @@ -0,0 +1,372 @@ +# LiteLLM Metadata Extraction 
Gateway + +A FastAPI-based metadata extraction gateway that sits in front of LiteLLM to inject evaluation metadata into LLM requests and track completions for distributed evaluation workflows. + +## Overview + +The Metadata Gateway is a proxy service that enhances LiteLLM by: +- **Extracting metadata from URL paths** and injecting it as Langfuse tags +- **Managing Langfuse credentials** per-project without exposing them to clients +- **Tracking completion insertion IDs** in Redis for completeness verification +- **Fetching and validating traces** from Langfuse with built-in retry logic + +This enables distributed evaluation systems to track which LLM completions belong to which evaluation runs, ensuring data completeness and proper attribution. + +## Architecture + +``` +┌─────────────┐ +│ Client │ +│ (SDK/CLI) │ +└──────┬──────┘ + │ Authorization: Bearer + │ POST /rollout_id/{id}/invocation_id/{id}/.../chat/completions + ▼ +┌─────────────────────────┐ +│ Metadata Gateway │ +│ (FastAPI Service) │ +│ - Extract metadata │ +│ - Inject Langfuse keys │ +│ - Generate UUID7 IDs │ +└──────┬──────────┬───────┘ + │ │ + ▼ ▼ + ┌────────┐ ┌─────────────┐ + │ Redis │ │ LiteLLM │ + │ │ │ Backend │ + │ Track │ │ │ + │ IDs │ └──────┬──────┘ + └────────┘ │ + ▼ + ┌─────────────┐ + │ Langfuse │ + │ (Tracing) │ + └─────────────┘ +``` + +### Components + +#### 1. **Metadata Gateway** (`proxy_core/`) + - **`app.py`**: Main FastAPI application with route definitions + - **`litellm.py`**: LiteLLM client for forwarding requests + - **`langfuse.py`**: Langfuse trace fetching with retry logic + - **`redis_utils.py`**: Redis operations for insertion ID tracking + - **`models.py`**: Pydantic models for configuration and responses + - **`auth.py`**: Authentication provider interface (extensible) + - **`main.py`**: Entry point for running the service + +#### 2. 
**Redis** + - Stores insertion IDs per rollout for completeness checking + - Uses Redis Sets: `rollout_id -> {insertion_id_1, insertion_id_2, ...}` + +#### 3. **LiteLLM Backend** + - Standard LiteLLM proxy for routing to LLM providers + - Configured with Langfuse callbacks for automatic tracing + +## Key Features + +### Metadata Injection +URL paths encode evaluation metadata that gets injected as Langfuse tags: +- `rollout_id`: Unique ID for a batch evaluation run +- `invocation_id`: ID for a single invocation within a rollout +- `experiment_id`: Experiment identifier +- `run_id`: Run identifier within an experiment +- `row_id`: Dataset row identifier +- `insertion_id`: Auto-generated UUID7 for this specific completion + +### Completeness Tracking +1. **On chat completion**: Generate UUID7 insertion_id and store in Redis +2. **On trace fetch**: Verify all expected insertion_ids are present in Langfuse +3. **Retry logic**: Automatic retries with exponential backoff for incomplete traces + +### Multi-Project Support +- Store Langfuse credentials for multiple projects in `secrets.yaml` +- Route requests to the correct project via `project_id` in URL or use default +- Credentials never exposed to clients + +## Setup + +### Prerequisites +- Docker and Docker Compose (recommended) +- Python 3.11+ (for local development) + +### Local Development: Docker Compose + +1. **Create secrets file:** + ```bash + cp proxy_core/secrets.yaml.example proxy_core/secrets.yaml + ``` + +2. **Edit `proxy_core/secrets.yaml`** with your Langfuse credentials. +**Important**: use your real Langfuse project ID (e.g. `cmg00asdf0123...`). + ```yaml + langfuse_keys: + my-project: + public_key: pk-lf-... + secret_key: sk-lf-... + default_project_id: my-project + ``` + +3. **Start services:** + ```bash + docker-compose up -d + ``` + +4. 
**Verify services are running:** + ```bash + curl http://localhost:4000/health + # Expected: {"status":"healthy","service":"metadata-proxy"} + ``` + +The gateway will be available at `http://localhost:4000`. + +## API Reference + +### Chat Completions + +#### With Full Metadata +``` +POST /rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/chat/completions +POST /project_id/{project_id}/rollout_id/{rollout_id}/.../chat/completions +``` + +**Features:** +- Extracts metadata from URL path +- Generates UUID7 insertion_id +- Injects Langfuse credentials +- Tracks insertion_id in Redis +- Forwards to LiteLLM + +**Request:** +```bash +curl -X POST http://localhost:4000/rollout_id/abc123/invocation_id/inv1/experiment_id/exp1/run_id/run1/row_id/row1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-..." \ + -d '{ + "model": "fireworks_ai/accounts/fireworks/models/llama-v3p3-70b-instruct", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +**Response:** Standard OpenAI chat completion response + +#### With Project Only +``` +POST /project_id/{project_id}/chat/completions +``` + +For completions that don't need rollout tracking. + +#### With Encoded Base URL +``` +POST /rollout_id/{rollout_id}/.../encoded_base_url/{encoded_base_url}/chat/completions +``` + +The `encoded_base_url` is base64-encoded URL string injected into the request body as `base_url`. + +### Trace Fetching + +#### Fetch All Langfuse Traces +``` +GET /traces?tags=rollout_id:abc123 +GET /v1/traces?tags=rollout_id:abc123 +GET /project_id/{project_id}/traces?tags=rollout_id:abc123 +GET /v1/project_id/{project_id}/traces?tags=rollout_id:abc123 +``` + +Waits for all expected insertion_ids to complete before returning all traces. 
+ +#### Fetch Latest Langfuse Trace (Pointwise) +``` +GET /traces/pointwise?tags=rollout_id:abc123 +GET /v1/traces/pointwise?tags=rollout_id:abc123 +GET /project_id/{project_id}/traces/pointwise?tags=rollout_id:abc123 +GET /v1/project_id/{project_id}/traces/pointwise?tags=rollout_id:abc123 +``` + +Returns only the latest trace (UUID v7 time-ordered). Much faster for pointwise evaluations where you only need the final accumulated result. + +**Required Query Parameters:** +- `tags`: Array of tags (must include at least one `rollout_id:*` tag) + +**Optional Query Parameters:** +- `limit`: Max traces to fetch (default: 100) +- `sample_size`: Random sample size if more traces found +- `user_id`, `session_id`, `name`, `environment`, `version`, `release`: Langfuse filters +- `fields`: Comma-separated fields to include +- `hours_back`: Fetch traces from last N hours +- `from_timestamp`, `to_timestamp`: ISO datetime strings for time range +- `sleep_between_gets`: Delay between trace.get calls (default: 2.5s) +- `max_retries`: Retry attempts for incomplete traces (default: 3) + +**Completeness Logic:** +1. Fetches traces from Langfuse matching tags +2. Extracts insertion_ids from trace tags +3. Compares with expected insertion_ids in Redis +4. Retries with exponential backoff if incomplete +5. Returns 404 if still incomplete after max_retries + +**Response:** +```json +{ + "project_id": "my-project", + "total_traces": 42, + "traces": [ + { + "id": "trace-123", + "name": "chat-completion", + "tags": ["rollout_id:abc123", "insertion_id:uuid7..."], + "input": {...}, + "output": {...}, + "observations": [...] + } + ] +} +``` + +### Health Check +``` +GET /health +``` + +Returns service health status. + +### Catch-All Proxy +``` +ANY /{path} +``` + +Forwards any other request to LiteLLM backend with API key injection. 
+ +## Configuration + +### Environment Variables + +| Variable | Required | Default | Description | +|----------|----------|---------|-------------| +| `LITELLM_URL` | Yes | - | URL of LiteLLM backend | +| `REDIS_HOST` | Yes | - | Redis hostname | +| `REDIS_PORT` | No | 6379 | Redis port | +| `REDIS_PASSWORD` | No | - | Redis password | +| `SECRETS_PATH` | No | `proxy_core/secrets.yaml` | Path to secrets file (YAML) | +| `LANGFUSE_HOST` | No | `https://cloud.langfuse.com` | Langfuse base URL | +| `REQUEST_TIMEOUT` | No | 300.0 | Request timeout (LLM calls) in seconds | +| `LOG_LEVEL` | No | INFO | Logging level | +| `PORT` | No | 4000 | Gateway port | + +### Secrets Configuration + +Create `proxy_core/secrets.yaml`: +```yaml +langfuse_keys: + project-1: + public_key: pk-lf-... + secret_key: sk-lf-... + project-2: + public_key: pk-lf-... + secret_key: sk-lf-... +default_project_id: project-1 +``` + +**Security:** `secrets.yaml` is ignored via `.gitignore`. + +### LiteLLM Configuration + +The `config_no_cache.yaml` configures LiteLLM: +```yaml +model_list: + - model_name: "*" + litellm_params: + model: "*" +litellm_settings: + success_callback: ["langfuse"] + failure_callback: ["langfuse"] + drop_params: True +general_settings: + allow_client_side_credentials: true +``` + +Key settings: +- **Wildcard model support**: Route any model to any provider +- **Langfuse callbacks**: Automatic tracing on success/failure +- **Client-side credentials**: Accept API keys from request body + +## Security Considerations + +### Authentication +- **Default**: No authentication (`NoAuthProvider`) +- **Extensible**: Implement custom `AuthProvider` for production +- **API Keys**: Client API keys forwarded to LiteLLM, never stored + +### Trace Fetching Security +- **Required rollout_id tag**: Prevents fetching all traces +- **Project isolation**: Projects can only access their own Langfuse data +- **Optional auth**: `/traces` endpoint can require authentication + +### Best Practices +1. 
**Never commit `secrets.yaml`** - use environment variables in production +2. **Use HTTPS** in production deployments +3. **Implement proper authentication** for production use +4. **Rotate Langfuse keys** regularly +5. **Monitor Redis memory** usage for large rollouts + +## Deployment + +### Docker Compose (Development) +```bash +docker-compose up -d +``` + +### Kubernetes +Create deployment with: +- Secrets for `secrets.yaml` and Redis credentials +- Service for internal/external access +- ConfigMap for LiteLLM config +- Redis StatefulSet or managed Redis service + +## Development + +### Project Structure +``` +eval_protocol/proxy/ +├── proxy_core/ # Main application package +│ ├── __init__.py +│ ├── app.py # FastAPI routes +│ ├── litellm.py # LiteLLM client +│ ├── langfuse.py # Langfuse integration +│ ├── redis_utils.py # Redis operations +│ ├── models.py # Pydantic models +│ ├── auth.py # Authentication +│ ├── main.py # Entry point +│ └── secrets.yaml.example +├── docker-compose.yml # Local development stack +├── Dockerfile.gateway # Gateway container +├── config_no_cache.yaml # LiteLLM config +├── requirements.txt # Python dependencies +└── README.md # This file +``` + +### Testing + +#### Test chat completion: +```bash +curl -X POST http://localhost:4000/rollout_id/test123/invocation_id/inv1/experiment_id/exp1/run_id/run1/row_id/row1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $FIREWORKS_API_KEY" \ + -d '{ + "model": "fireworks_ai/accounts/fireworks/models/llama-v3p3-70b-instruct", + "messages": [{"role": "user", "content": "Say hello"}] + }' +``` + +#### Test trace fetching: +```bash +curl "http://localhost:4000/traces?tags=rollout_id:test123" \ + -H "Authorization: Bearer your-auth-token" +``` + +#### Check Redis: +```bash +redis-cli +> SMEMBERS test123 # View insertion_ids for rollout +``` diff --git a/eval_protocol/proxy/__init__.py b/eval_protocol/proxy/__init__.py new file mode 100644 index 00000000..c471064c
--- /dev/null +++ b/eval_protocol/proxy/__init__.py @@ -0,0 +1,18 @@ +""" +LiteLLM Metadata Extraction Gateway + +A proxy service for extracting evaluation metadata from URL paths and managing +Langfuse tracing for distributed evaluation workflows. +""" + +from .proxy_core import create_app, AuthProvider, NoAuthProvider, ProxyConfig, ChatParams, TracesParams, AccountInfo + +__all__ = [ + "create_app", + "AuthProvider", + "NoAuthProvider", + "ProxyConfig", + "ChatParams", + "TracesParams", + "AccountInfo", +] diff --git a/eval_protocol/proxy/config_no_cache.yaml b/eval_protocol/proxy/config_no_cache.yaml new file mode 100644 index 00000000..7adb5a72 --- /dev/null +++ b/eval_protocol/proxy/config_no_cache.yaml @@ -0,0 +1,10 @@ +model_list: + - model_name: "*" + litellm_params: + model: "*" +litellm_settings: + success_callback: ["langfuse"] + failure_callback: ["langfuse"] + drop_params: True +general_settings: + allow_client_side_credentials: true diff --git a/eval_protocol/proxy/docker-compose.yml b/eval_protocol/proxy/docker-compose.yml new file mode 100644 index 00000000..a6058e0e --- /dev/null +++ b/eval_protocol/proxy/docker-compose.yml @@ -0,0 +1,70 @@ +services: + # Redis - For tracking assistant message counts + redis: + image: redis:7-alpine + platform: linux/amd64 + container_name: proxy-redis + ports: + - "6379:6379" # Expose for debugging if needed + networks: + - litellm-network + restart: unless-stopped + command: redis-server --appendonly yes + volumes: + - redis-data:/data + + # LiteLLM Backend - Handles actual LLM proxying + litellm-backend: + image: litellm/litellm:v1.77.3-stable + platform: linux/amd64 + container_name: litellm-backend + command: ["--config", "/app/config.yaml", "--port", "4000", "--host", "0.0.0.0"] + # If you want to be able to use other model providers like OpenAI, Anthropic, etc., you need to set keys in .env file. 
+ env_file: + - .env # Load API keys from .env file + environment: + - LANGFUSE_PUBLIC_KEY=dummy # Set dummy public and private key so Langfuse instance initializes in LiteLLM, then real keys get sent in proxy + - LANGFUSE_SECRET_KEY=dummy + volumes: + - ./config_no_cache.yaml:/app/config.yaml:ro + ports: + - "4001:4000" # Expose on 4001 for direct access if needed + networks: + - litellm-network + restart: unless-stopped + + # Metadata Gateway - Public-facing service that extracts metadata from URLs + metadata-gateway: + build: + context: . + dockerfile: Dockerfile.gateway + container_name: metadata-gateway + environment: + # Point to the LiteLLM backend service + - LITELLM_URL=http://litellm-backend:4000 + - PORT=4000 + # Redis configuration for assistant message counting + - REDIS_HOST=redis + - REDIS_PORT=6379 + # No password for local Redis + - REQUEST_TIMEOUT=300 + # Logging level: INFO (default) + - LOG_LEVEL=INFO + # Langfuse and secrets + - SECRETS_PATH=/app/proxy_core/secrets.yaml + - LANGFUSE_HOST=${LANGFUSE_HOST:-https://cloud.langfuse.com} + ports: + - "4000:4000" # Main public-facing port + networks: + - litellm-network + depends_on: + - litellm-backend + - redis + restart: unless-stopped + +networks: + litellm-network: + driver: bridge + +volumes: + redis-data: diff --git a/eval_protocol/proxy/proxy_core/__init__.py b/eval_protocol/proxy/proxy_core/__init__.py new file mode 100644 index 00000000..053f922f --- /dev/null +++ b/eval_protocol/proxy/proxy_core/__init__.py @@ -0,0 +1,13 @@ +from .models import ProxyConfig, ChatParams, TracesParams, AccountInfo +from .auth import AuthProvider, NoAuthProvider +from .app import create_app + +__all__ = [ + "ProxyConfig", + "ChatParams", + "TracesParams", + "AccountInfo", + "create_app", + "AuthProvider", + "NoAuthProvider", +] diff --git a/eval_protocol/proxy/proxy_core/app.py b/eval_protocol/proxy/proxy_core/app.py new file mode 100644 index 00000000..528d467e --- /dev/null +++ 
b/eval_protocol/proxy/proxy_core/app.py @@ -0,0 +1,305 @@ +""" +Metadata Extraction Gateway +A FastAPI service that sits in front of LiteLLM and extracts metadata from URL paths. +""" + +from fastapi import FastAPI, Depends, Request, Query +from typing import Optional, List +import os +import redis +import logging +import yaml +from pathlib import Path +import sys +from contextlib import asynccontextmanager + +from .models import ProxyConfig, LangfuseTracesResponse, TracesParams, ChatParams, ChatRequestHook, TracesRequestHook +from .auth import AuthProvider, NoAuthProvider +from .litellm import handle_chat_completion, proxy_to_litellm +from .langfuse import fetch_langfuse_traces, pointwise_fetch_langfuse_trace + +# Configure logging before any other imports (so all modules inherit this config) +log_level = os.getenv("LOG_LEVEL", "INFO").upper() +if not logging.getLogger().hasHandlers(): + logging.basicConfig( + level=getattr(logging, log_level), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], + ) + +logger = logging.getLogger(__name__) + + +def build_proxy_config( + preprocess_chat_request: Optional[ChatRequestHook] = None, + preprocess_traces_request: Optional[TracesRequestHook] = None, +) -> ProxyConfig: + """Load environment and secrets, and build ProxyConfig""" + # Env + litellm_url = os.getenv("LITELLM_URL") + if not litellm_url: + raise ValueError("LITELLM_URL environment variable must be set") + request_timeout = float(os.getenv("REQUEST_TIMEOUT", "300.0")) + langfuse_host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com") + + # Secrets - use SECRETS_PATH env var if set, otherwise default to proxy/secrets.yaml + secrets_path_str = os.getenv("SECRETS_PATH") + if secrets_path_str: + secrets_path = Path(secrets_path_str) + else: + secrets_path = Path(__file__).parent / "secrets.yaml" + if not secrets_path.exists(): + raise ValueError( + "Secrets file not found! 
Please create it from secrets.yaml.example:\n" + " cp eval_protocol/proxy/proxy_core/secrets.yaml.example eval_protocol/proxy/proxy_core/secrets.yaml\n" + "Then add your Langfuse API keys to the secrets file" + ) + try: + with open(secrets_path, "r") as f: + secrets_config = yaml.safe_load(f) + langfuse_keys = secrets_config["langfuse_keys"] + default_project_id = secrets_config["default_project_id"] + logger.info(f"Loaded {len(langfuse_keys)} Langfuse project(s) from {secrets_path.name}") + except KeyError as e: + raise ValueError(f"Missing required key in secrets file: {e}") + except yaml.YAMLError as e: + raise ValueError(f"Invalid format in secrets file {secrets_path.name}: {e}") + + return ProxyConfig( + litellm_url=litellm_url, + request_timeout=request_timeout, + langfuse_host=langfuse_host, + langfuse_keys=langfuse_keys, + default_project_id=default_project_id, + preprocess_chat_request=preprocess_chat_request, + preprocess_traces_request=preprocess_traces_request, + ) + + +def init_redis() -> redis.Redis: + """Initialize and return a Redis client from environment variables.""" + redis_host = os.getenv("REDIS_HOST") + if not redis_host: + raise ValueError("REDIS_HOST environment variable must be set") + redis_port = int(os.getenv("REDIS_PORT", "6379")) + redis_password = os.getenv("REDIS_PASSWORD") + + try: + client = redis.Redis( + host=redis_host, + port=redis_port, + password=redis_password if redis_password else None, + decode_responses=True, + socket_connect_timeout=5, + socket_timeout=5, + retry_on_timeout=True, + ) + client.ping() + logger.info(f"Connected to Redis at {redis_host}:{redis_port}") + return client + except Exception as e: + raise ConnectionError(f"Failed to connect to Redis at {redis_host}:{redis_port}: {e}") + + +def create_app( + auth_provider: AuthProvider = NoAuthProvider(), + preprocess_chat_request: Optional[ChatRequestHook] = None, + preprocess_traces_request: Optional[TracesRequestHook] = None, +) -> FastAPI: + 
@asynccontextmanager + async def lifespan(app: FastAPI): + # Build runtime on startup + app.state.config = build_proxy_config(preprocess_chat_request, preprocess_traces_request) + app.state.redis = init_redis() + + try: + yield + finally: + try: + app.state.redis.close() + except Exception: + pass + + app = FastAPI(title="LiteLLM Metadata Proxy", lifespan=lifespan) + + def get_config(request: Request) -> ProxyConfig: + return request.app.state.config + + def get_redis(request: Request) -> redis.Redis: + return request.app.state.redis + + def get_traces_params( + tags: Optional[List[str]] = Query(default=None), + project_id: Optional[str] = None, + limit: int = 100, + sample_size: Optional[int] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + name: Optional[str] = None, + environment: Optional[str] = None, + version: Optional[str] = None, + release: Optional[str] = None, + fields: Optional[str] = None, + hours_back: Optional[int] = None, + from_timestamp: Optional[str] = None, + to_timestamp: Optional[str] = None, + sleep_between_gets: float = 2.5, + max_retries: int = 3, + ) -> TracesParams: + return TracesParams( + tags=tags, + project_id=project_id, + limit=limit, + sample_size=sample_size, + user_id=user_id, + session_id=session_id, + name=name, + environment=environment, + version=version, + release=release, + fields=fields, + hours_back=hours_back, + from_timestamp=from_timestamp, + to_timestamp=to_timestamp, + sleep_between_gets=sleep_between_gets, + max_retries=max_retries, + ) + + async def require_auth(request: Request) -> None: + account_info = auth_provider.validate_and_return_account_info(request) + request.state.account_id = account_info.account_id if account_info else None + return None + + # ===================== + # Chat completion routes + # ===================== + @app.post( + 
"/project_id/{project_id}/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/chat/completions" + ) + @app.post( + "/v1/project_id/{project_id}/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/chat/completions" + ) + @app.post( + "/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/chat/completions" + ) + @app.post( + "/v1/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/chat/completions" + ) + @app.post( + "/project_id/{project_id}/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/encoded_base_url/{encoded_base_url}/chat/completions" + ) + @app.post( + "/v1/project_id/{project_id}/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/encoded_base_url/{encoded_base_url}/chat/completions" + ) + @app.post( + "/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/encoded_base_url/{encoded_base_url}/chat/completions" + ) + @app.post( + "/v1/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/encoded_base_url/{encoded_base_url}/chat/completions" + ) + async def chat_completion_with_full_metadata( + rollout_id: str, + invocation_id: str, + experiment_id: str, + run_id: str, + row_id: str, + request: Request, + project_id: Optional[str] = None, + encoded_base_url: Optional[str] = None, + config: ProxyConfig = Depends(get_config), + redis_client: redis.Redis = Depends(get_redis), + _: None = Depends(require_auth), + ): + params = ChatParams( + project_id=project_id, + rollout_id=rollout_id, + invocation_id=invocation_id, + experiment_id=experiment_id, + 
run_id=run_id, + row_id=row_id, + encoded_base_url=encoded_base_url, + ) + return await handle_chat_completion( + config=config, + redis_client=redis_client, + request=request, + params=params, + ) + + @app.post("/project_id/{project_id}/chat/completions") + @app.post("/v1/project_id/{project_id}/chat/completions") + async def chat_completion_with_project_only( + project_id: str, + request: Request, + config: ProxyConfig = Depends(get_config), + redis_client: redis.Redis = Depends(get_redis), + _: None = Depends(require_auth), + ): + params = ChatParams(project_id=project_id) + return await handle_chat_completion( + config=config, + redis_client=redis_client, + request=request, + params=params, + ) + + # =============== + # Traces routes + # =============== + @app.get("/traces", response_model=LangfuseTracesResponse) + @app.get("/v1/traces", response_model=LangfuseTracesResponse) + @app.get("/project_id/{project_id}/traces", response_model=LangfuseTracesResponse) + @app.get("/v1/project_id/{project_id}/traces", response_model=LangfuseTracesResponse) + async def get_langfuse_traces( + request: Request, + params: TracesParams = Depends(get_traces_params), + project_id: Optional[str] = None, + config: ProxyConfig = Depends(get_config), + redis_client: redis.Redis = Depends(get_redis), + _: None = Depends(require_auth), + ) -> LangfuseTracesResponse: + if project_id is not None: + params.project_id = project_id + return await fetch_langfuse_traces( + config=config, + redis_client=redis_client, + request=request, + params=params, + ) + + @app.get("/traces/pointwise", response_model=LangfuseTracesResponse) + @app.get("/v1/traces/pointwise", response_model=LangfuseTracesResponse) + @app.get("/project_id/{project_id}/traces/pointwise", response_model=LangfuseTracesResponse) + @app.get("/v1/project_id/{project_id}/traces/pointwise", response_model=LangfuseTracesResponse) + async def pointwise_get_langfuse_trace( + request: Request, + params: TracesParams = 
Depends(get_traces_params), + project_id: Optional[str] = None, + config: ProxyConfig = Depends(get_config), + redis_client: redis.Redis = Depends(get_redis), + _: None = Depends(require_auth), + ) -> LangfuseTracesResponse: + if project_id is not None: + params.project_id = project_id + return await pointwise_fetch_langfuse_trace( + config=config, + redis_client=redis_client, + request=request, + params=params, + ) + + # Health + @app.get("/health") + async def health(): + return {"status": "healthy", "service": "metadata-proxy"} + + # Catch-all + @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) + async def catch_all_proxy( + path: str, + request: Request, + config: ProxyConfig = Depends(get_config), + ): + return await proxy_to_litellm(config, path, request) + + return app diff --git a/eval_protocol/proxy/proxy_core/auth.py b/eval_protocol/proxy/proxy_core/auth.py new file mode 100644 index 00000000..cdbf6d3c --- /dev/null +++ b/eval_protocol/proxy/proxy_core/auth.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +import logging +from fastapi import Request +from typing import Optional +from .models import AccountInfo + +logger = logging.getLogger(__name__) + + +class AuthProvider(ABC): + @abstractmethod + def validate_and_return_account_info(self, request: Request) -> Optional[AccountInfo]: ... + + +class NoAuthProvider(AuthProvider): + def validate_and_return_account_info(self, request: Request) -> Optional[AccountInfo]: + return None diff --git a/eval_protocol/proxy/proxy_core/langfuse.py b/eval_protocol/proxy/proxy_core/langfuse.py new file mode 100644 index 00000000..9764e3c9 --- /dev/null +++ b/eval_protocol/proxy/proxy_core/langfuse.py @@ -0,0 +1,528 @@ +""" +Traces fetching handler for Langfuse integration. 
"""

import time
import random
import logging
import asyncio
from typing import List, Optional, Dict, Any, Set
from datetime import datetime, timedelta
from fastapi import HTTPException, Request
import redis
from .redis_utils import get_insertion_ids
from .models import ProxyConfig, LangfuseTracesResponse, TraceResponse, TracesParams

logger = logging.getLogger(__name__)


def _extract_tag_value(tags: Optional[List[str]], prefix: str) -> Optional[str]:
    """Extract value from a tag with the given prefix (e.g., 'rollout_id:' or 'insertion_id:').

    Returns the text after the first ':' of the first matching tag, or
    None when tags is empty/None or no tag starts with the prefix.
    """
    if not tags:
        return None
    for tag in tags:
        if tag.startswith(prefix):
            # Split on the first ':' only, so values may themselves contain ':'.
            return tag.split(":", 1)[1]
    return None


def _serialize_trace_to_dict(trace_full: Any) -> Dict[str, Any]:
    """Convert Langfuse trace object to dict format.

    Attributes are read defensively with getattr so partially populated
    trace objects (or SDK version differences) degrade to None/[] rather
    than raising. Timestamps are stringified for JSON transport.
    """
    timestamp = getattr(trace_full, "timestamp", None)

    return {
        "id": trace_full.id,
        "name": getattr(trace_full, "name", None),
        "user_id": getattr(trace_full, "user_id", None),
        "session_id": getattr(trace_full, "session_id", None),
        "tags": getattr(trace_full, "tags", []),
        "timestamp": str(timestamp) if timestamp else None,
        "input": getattr(trace_full, "input", None),
        "output": getattr(trace_full, "output", None),
        "metadata": getattr(trace_full, "metadata", None),
        "observations": [
            {
                "id": obs.id,
                "type": getattr(obs, "type", None),
                "name": getattr(obs, "name", None),
                "start_time": str(getattr(obs, "start_time", None)) if getattr(obs, "start_time", None) else None,
                "end_time": str(getattr(obs, "end_time", None)) if getattr(obs, "end_time", None) else None,
                "input": getattr(obs, "input", None),
                "output": getattr(obs, "output", None),
                "parent_observation_id": getattr(obs, "parent_observation_id", None),
            }
            for obs in getattr(trace_full, "observations", [])
        ]
        if hasattr(trace_full, "observations")
        else [],
    }


async def _fetch_trace_list_with_retry(
    langfuse_client: Any,
    page: int,
    limit: int,
    tags: Optional[List[str]],
    user_id: Optional[str],
    session_id: Optional[str],
    name: Optional[str],
    environment: Optional[str],
    version: Optional[str],
    release: Optional[str],
    fields: Optional[str],
    from_ts: Optional[datetime],
    to_ts: Optional[datetime],
    max_retries: int,
) -> Any:
    """Fetch trace list with rate limit retry logic.

    Retries with exponential backoff on rate limits (detected by '429' in
    the error text) and on an empty first page, which can occur while
    Langfuse is still indexing freshly pushed traces.

    Raises:
        HTTPException(404): retryable failures persisted for max_retries attempts.
        HTTPException(500): any other (non-retryable) error.
    """
    list_retries = 0
    while list_retries < max_retries:
        try:
            traces = langfuse_client.api.trace.list(
                page=page,
                limit=limit,
                tags=tags,
                user_id=user_id,
                session_id=session_id,
                name=name,
                environment=environment,
                version=version,
                release=release,
                fields=fields,
                from_timestamp=from_ts,
                to_timestamp=to_ts,
                order_by="timestamp.desc",
            )

            # If no results, possible due to indexing delay--remote rollout processor just finished pushing rows to Langfuse
            if traces and traces.meta and traces.meta.total_items == 0 and page == 1:
                raise Exception("Empty results")

            return traces
        except Exception as e:
            list_retries += 1
            if list_retries < max_retries and ("429" in str(e) or "Empty results" in str(e)):
                sleep_time = 2**list_retries  # Exponential backoff for rate limits
                logger.warning(
                    "Retrying trace.list in %ds (attempt %d/%d): %s", sleep_time, list_retries, max_retries, str(e)
                )
                await asyncio.sleep(sleep_time)
            elif list_retries == max_retries:
                # Return 404 if we've retried max_retries
                # TODO: write some tests around proxy exception handling
                logger.error("Failed to fetch trace list after %d retries: %s", max_retries, e)
                raise HTTPException(
                    status_code=404, detail=f"Failed to fetch traces after {max_retries} retries: {str(e)}"
                )
            else:
                # Catch all other exceptions
                logger.error("Failed to fetch trace list: %s", e)
                raise HTTPException(status_code=500, detail=f"Failed to fetch traces: {str(e)}")


async def _fetch_trace_detail_with_retry(
    langfuse_client: Any,
    trace_id: str,
    max_retries: int,
) -> Optional[Any]:
    """Fetch full trace details with rate limit retry logic.

    Unlike _fetch_trace_list_with_retry this does not raise HTTPException:
    a missing trace (404 / 'Not Found') or a persistent failure yields
    None so the caller can skip this trace and keep collecting the rest.
    """
    detail_retries = 0
    while detail_retries < max_retries:
        try:
            trace_full = langfuse_client.api.trace.get(trace_id)
            return trace_full
        except Exception as e:
            detail_retries += 1
            if "429" in str(e) and detail_retries < max_retries:
                sleep_time = 2**detail_retries  # Exponential backoff for rate limits
                logger.warning(
                    "Rate limit hit on trace.get(%s), retrying in %ds (attempt %d/%d)",
                    trace_id,
                    sleep_time,
                    detail_retries,
                    max_retries,
                )
                await asyncio.sleep(sleep_time)
            elif "Not Found" in str(e) or "404" in str(e):
                logger.debug("Trace %s not found, skipping", trace_id)
                return None
            else:
                logger.warning("Failed to fetch trace %s after %d retries: %s", trace_id, max_retries, e)
                return None


async def fetch_langfuse_traces(
    config: ProxyConfig,
    redis_client: redis.Redis,
    request: Request,
    params: TracesParams,
):
    """
    Fetch full traces from Langfuse for the specified project.

    This endpoint uses the stored Langfuse keys for the project and polls
    traces based on the provided filters.

    If project_id is not provided, uses the default project.

    Returns a list of full trace objects (including observations) in JSON format.
    """

    # Preprocess traces request
    if config.preprocess_traces_request:
        params = config.preprocess_traces_request(request, params)

    # Unpack the (possibly hook-rewritten) params into locals.
    tags = params.tags
    project_id = params.project_id
    limit = params.limit
    sample_size = params.sample_size
    user_id = params.user_id
    session_id = params.session_id
    name = params.name
    environment = params.environment
    version = params.version
    release = params.release
    fields = params.fields
    hours_back = params.hours_back
    from_timestamp = params.from_timestamp
    to_timestamp = params.to_timestamp
    sleep_between_gets = params.sleep_between_gets
    max_retries = params.max_retries

    # Use default project if not specified
    if project_id is None:
        project_id = config.default_project_id

    # Validate project_id
    if project_id not in config.langfuse_keys:
        raise HTTPException(
            status_code=404,
            detail=f"Project ID '{project_id}' not found. Available projects: {list(config.langfuse_keys.keys())}",
        )

    # Extract rollout_id from tags for Redis lookup
    rollout_id = _extract_tag_value(tags, "rollout_id:")

    try:
        # Import the Langfuse adapter
        from langfuse import Langfuse

        # Create Langfuse client with the project's keys
        langfuse_client = Langfuse(
            public_key=config.langfuse_keys[project_id]["public_key"],
            secret_key=config.langfuse_keys[project_id]["secret_key"],
            host=config.langfuse_host,
        )

        # Parse datetime strings if provided
        from_ts = None
        to_ts = None
        if from_timestamp:
            from_ts = datetime.fromisoformat(from_timestamp.replace("Z", "+00:00"))
        if to_timestamp:
            to_ts = datetime.fromisoformat(to_timestamp.replace("Z", "+00:00"))

        # Determine time window: explicit from/to takes precedence over hours_back
        # NOTE(review): datetime.now() is naive local time, while explicit
        # timestamps above are parsed timezone-aware -- confirm whether
        # Langfuse expects UTC here.
        if from_ts is None and to_ts is None and hours_back:
            to_ts = datetime.now()
            from_ts = to_ts - timedelta(hours=hours_back)

        # Get expected insertion_ids from Redis for completeness checking
        expected_ids: Set[str] = set()
        if rollout_id:
            expected_ids = get_insertion_ids(redis_client, rollout_id)
            logger.info(f"Fetching traces for rollout_id '{rollout_id}', expecting {len(expected_ids)} insertion_ids")
            if not expected_ids:
                logger.warning(
                    f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty traces."
                )
                raise HTTPException(
                    status_code=500,
                    detail=f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty traces.",
                )

        # Track all traces we've collected across retry attempts
        trace_ids: Set[str] = set()  # Langfuse trace IDs (for deduplication)
        all_traces: List[Dict[str, Any]] = []  # Full trace data
        insertion_ids: Set[str] = set()  # Insertion IDs extracted from traces (for completeness check)

        for retry in range(max_retries):
            # On first attempt, use rollout_id tag. On retries, target missing insertion_ids
            if retry == 0:
                fetch_tags = tags
            else:
                # Build targeted tags for missing insertion_ids
                missing_ids = expected_ids - insertion_ids
                fetch_tags = [f"insertion_id:{id}" for id in missing_ids]
                logger.info(
                    f"Retry {retry}: Targeting {len(fetch_tags)} missing insertion_ids for rollout '{rollout_id}' (last5): {[id[-5:] for id in sorted(missing_ids)[:10]]}{'...' if len(missing_ids) > 10 else ''}"
                )

            current_page = 1
            collected = 0

            # Page through the trace list until we have `limit` traces or run out.
            while collected < limit:
                current_page_limit = min(100, limit - collected)  # Langfuse API max is 100

                # Fetch trace list with rate limit retry logic
                traces = await _fetch_trace_list_with_retry(
                    langfuse_client,
                    current_page,
                    current_page_limit,
                    fetch_tags,
                    user_id,
                    session_id,
                    name,
                    environment,
                    version,
                    release,
                    fields,
                    from_ts,
                    to_ts,
                    max_retries,
                )

                if not traces or not traces.data:
                    logger.debug("No more traces found on page %d", current_page)
                    break

                # For traces we find not in our current list of traces, do trace.get
                for trace_info in traces.data:
                    if trace_info.id in trace_ids:
                        continue  # Skip already processed traces

                    # Optional throttle between detail fetches to avoid 429s.
                    if sleep_between_gets > 0:
                        await asyncio.sleep(sleep_between_gets)

                    # Fetch full trace with rate limit retry logic
                    trace_full = await _fetch_trace_detail_with_retry(
                        langfuse_client,
                        trace_info.id,
                        max_retries,
                    )

                    if trace_full:
                        try:
                            trace_dict = _serialize_trace_to_dict(trace_full)
                            all_traces.append(trace_dict)
                            trace_ids.add(trace_info.id)

                            # Extract insertion_id for completeness checking
                            insertion_id = _extract_tag_value(trace_dict.get("tags", []), "insertion_id:")
                            if insertion_id:
                                insertion_ids.add(insertion_id)
                                logger.debug(f"Found insertion_id '{insertion_id}' for rollout '{rollout_id}'")

                        except Exception as e:
                            logger.warning("Failed to serialize trace %s: %s", trace_info.id, e)
                            continue

                # Counts every listed trace (including duplicates skipped above),
                # so `collected` tracks pagination progress, not unique traces.
                collected += len(traces.data)

                # Check if we have more pages
                if hasattr(traces.meta, "page") and hasattr(traces.meta, "total_pages"):
                    if traces.meta.page >= traces.meta.total_pages:
                        break
                elif len(traces.data) < current_page_limit:
                    break

                current_page += 1

            # If we have all expected completions or more, return traces. At least once is ok.
            if expected_ids <= insertion_ids:
                logger.info(
                    f"Traces complete for rollout '{rollout_id}': {len(insertion_ids)}/{len(expected_ids)} insertion_ids found, returning {len(all_traces)} traces"
                )
                if sample_size is not None and len(all_traces) > sample_size:
                    all_traces = random.sample(all_traces, sample_size)
                    logger.info(f"Sampled down to {sample_size} traces")

                return LangfuseTracesResponse(
                    project_id=project_id,
                    total_traces=len(all_traces),
                    traces=[TraceResponse(**trace) for trace in all_traces],
                )

            # If it doesn't match, wait and do loop again (exponential backoff)
            if retry < max_retries - 1:
                wait_time = 2**retry
                still_missing = expected_ids - insertion_ids
                logger.info(
                    f"Attempt {retry + 1}/{max_retries}. Found {len(insertion_ids)}/{len(expected_ids)} for rollout '{rollout_id}'. Still missing (last5): {[id[-5:] for id in sorted(still_missing)[:10]]}{'...' if len(still_missing) > 10 else ''}. Waiting {wait_time}s..."
                )
                await asyncio.sleep(wait_time)

        # All retry attempts exhausted without reaching completeness.
        logger.error(
            f"Incomplete traces for rollout_id '{rollout_id}': Found {len(insertion_ids)}/{len(expected_ids)} completions."
        )
        raise HTTPException(
            status_code=404,
            detail=f"Incomplete traces for rollout_id '{rollout_id}': Found {len(insertion_ids)}/{len(expected_ids)} completions.",
        )

    except ImportError:
        raise HTTPException(status_code=500, detail="Langfuse SDK not installed. Install with: pip install langfuse")
    except HTTPException:
        # Re-raise our own HTTP errors untouched (don't wrap them in a 500).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error fetching traces from Langfuse: {str(e)}")


async def pointwise_fetch_langfuse_trace(
    config: ProxyConfig,
    redis_client: redis.Redis,
    request: Request,
    params: TracesParams,
):
    """
    Fetch the latest trace from Langfuse for the specified project.

    Since insertion_ids are UUID v7 (time-ordered), we only fetch the last one
    as it contains all accumulated information from the pointwise evaluation.

    Returns a single trace object or raises if not found.
    """

    # Preprocess traces request
    if config.preprocess_traces_request:
        params = config.preprocess_traces_request(request, params)

    tags = params.tags
    project_id = params.project_id
    user_id = params.user_id
    session_id = params.session_id
    name = params.name
    environment = params.environment
    version = params.version
    release = params.release
    fields = params.fields
    hours_back = params.hours_back
    from_timestamp = params.from_timestamp
    to_timestamp = params.to_timestamp
    # NOTE(review): sleep_between_gets is unpacked but not used on this
    # pointwise path (only one trace.get is issued per attempt).
    sleep_between_gets = params.sleep_between_gets
    max_retries = params.max_retries

    # Use default project if not specified
    if project_id is None:
        project_id = config.default_project_id

    # Validate project_id
    if project_id not in config.langfuse_keys:
        raise HTTPException(
            status_code=404,
            detail=f"Project ID '{project_id}' not found. Available projects: {list(config.langfuse_keys.keys())}",
        )

    # Extract rollout_id from tags for Redis lookup
    rollout_id = _extract_tag_value(tags, "rollout_id:")

    try:
        # Import the Langfuse adapter
        from langfuse import Langfuse

        # Create Langfuse client with the project's keys
        logger.debug(f"Connecting to Langfuse at {config.langfuse_host} for project '{project_id}'")
        langfuse_client = Langfuse(
            public_key=config.langfuse_keys[project_id]["public_key"],
            secret_key=config.langfuse_keys[project_id]["secret_key"],
            host=config.langfuse_host,
        )

        # Parse datetime strings if provided
        from_ts = None
        to_ts = None
        if from_timestamp:
            from_ts = datetime.fromisoformat(from_timestamp.replace("Z", "+00:00"))
        if to_timestamp:
            to_ts = datetime.fromisoformat(to_timestamp.replace("Z", "+00:00"))

        # Determine time window: explicit from/to takes precedence over hours_back
        # NOTE(review): naive local time vs tz-aware parsed timestamps, as in
        # fetch_langfuse_traces -- confirm UTC expectation.
        if from_ts is None and to_ts is None and hours_back:
            to_ts = datetime.now()
            from_ts = to_ts - timedelta(hours=hours_back)

        # Get insertion_ids from Redis to find the latest one
        expected_ids: Set[str] = set()
        if rollout_id:
            expected_ids = get_insertion_ids(redis_client, rollout_id)
            logger.info(
                f"Pointwise fetch for rollout_id '{rollout_id}', found {len(expected_ids)} insertion_ids in Redis"
            )
            if not expected_ids:
                logger.warning(
                    f"No insertion_ids found in Redis for rollout '{rollout_id}'. Cannot determine latest trace."
                )
                raise HTTPException(
                    status_code=500,
                    detail=f"No insertion_ids found in Redis for rollout '{rollout_id}'. Cannot determine latest trace.",
                )

        # Get the latest (last) insertion_id since UUID v7 is time-ordered
        # NOTE(review): if the request carried no rollout_id tag, expected_ids
        # is empty here and max() raises ValueError (surfaced as a 500 by the
        # generic handler below) -- confirm whether rollout_id should be
        # mandatory for this endpoint.
        latest_insertion_id = max(expected_ids)  # UUID v7 max = newest
        logger.info(f"Targeting latest insertion_id: {latest_insertion_id} for rollout '{rollout_id}'")

        for retry in range(max_retries):
            # Fetch trace list targeting the latest insertion_id
            traces = await _fetch_trace_list_with_retry(
                langfuse_client,
                page=1,
                limit=1,  # Only need the one trace
                tags=[f"insertion_id:{latest_insertion_id}"],
                user_id=user_id,
                session_id=session_id,
                name=name,
                environment=environment,
                version=version,
                release=release,
                fields=fields,
                from_ts=from_ts,
                to_ts=to_ts,
                max_retries=max_retries,
            )

            if traces and traces.data:
                # Get the trace info
                trace_info = traces.data[0]
                logger.debug(f"Found trace {trace_info.id} for latest insertion_id {latest_insertion_id}")

                # Fetch full trace details
                trace_full = await _fetch_trace_detail_with_retry(
                    langfuse_client,
                    trace_info.id,
                    max_retries,
                )

                if trace_full:
                    trace_dict = _serialize_trace_to_dict(trace_full)
                    logger.info(
                        f"Successfully fetched latest trace for rollout '{rollout_id}', insertion_id: {latest_insertion_id}"
                    )
                    return LangfuseTracesResponse(
                        project_id=project_id,
                        total_traces=1,
                        traces=[TraceResponse(**trace_dict)],
                    )

            # If not successful and not last retry, sleep and continue
            if retry < max_retries - 1:
                wait_time = 2**retry
                logger.info(
                    f"Pointwise fetch attempt {retry + 1}/{max_retries} failed for rollout '{rollout_id}', insertion_id: {latest_insertion_id}. Retrying in {wait_time}s..."
                )
                await asyncio.sleep(wait_time)

        # After all retries failed
        logger.error(
            f"Failed to fetch latest trace for rollout '{rollout_id}', insertion_id: {latest_insertion_id} after {max_retries} retries"
        )
        raise HTTPException(
            status_code=404,
            detail=f"Failed to fetch latest trace for rollout '{rollout_id}' after {max_retries} retries",
        )

    except ImportError:
        raise HTTPException(status_code=500, detail="Langfuse SDK not installed. Install with: pip install langfuse")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error fetching latest trace from Langfuse: {str(e)}")

diff --git a/eval_protocol/proxy/proxy_core/litellm.py b/eval_protocol/proxy/proxy_core/litellm.py
new file mode 100644
index 00000000..cdd2383b
--- /dev/null
+++ b/eval_protocol/proxy/proxy_core/litellm.py
@@ -0,0 +1,173 @@
"""
LiteLLM client - handles all communication with LiteLLM service.
"""

import json
import base64
import httpx
import logging
from uuid6 import uuid7
from fastapi import Request, Response, HTTPException
import redis
from .redis_utils import register_insertion_id
from .models import ProxyConfig, ChatParams

logger = logging.getLogger(__name__)


async def handle_chat_completion(
    config: ProxyConfig,
    redis_client: redis.Redis,
    request: Request,
    params: ChatParams,
) -> Response:
    """
    Handle chat completion requests and forward to LiteLLM.

    If metadata IDs (rollout_id, etc.) are provided, they'll be added as tags
    and the assistant message count will be tracked in Redis.

    If encoded_base_url is provided, it will be decoded and added to the request.
    """
    body = await request.body()
    data = json.loads(body) if body else {}

    # Optional deployment-specific hook: may rewrite both body and params.
    if config.preprocess_chat_request:
        data, params = config.preprocess_chat_request(data, request, params)

    project_id = params.project_id
    rollout_id = params.rollout_id
    invocation_id = params.invocation_id
    experiment_id = params.experiment_id
    run_id = params.run_id
    row_id = params.row_id
    encoded_base_url = params.encoded_base_url

    # Use default project if not specified
    if project_id is None:
        project_id = config.default_project_id

    # Decode and add base_url if provided
    if encoded_base_url:
        try:
            # Decode from URL-safe base64
            decoded_bytes = base64.urlsafe_b64decode(encoded_base_url)
            base_url = decoded_bytes.decode("utf-8")
            data["base_url"] = base_url
            logger.debug(f"Decoded base_url: {base_url}")
        except Exception as e:
            logger.error(f"Failed to decode base_url: {e}")
            raise HTTPException(status_code=400, detail=f"Invalid encoded_base_url: {str(e)}")

    # Extract API key from Authorization header and inject into request body
    auth_header = request.headers.get("authorization", "")
    if auth_header.startswith("Bearer "):
        api_key = auth_header.replace("Bearer ", "").strip()
        # Only inject API key if model is a Fireworks model
        model = data.get("model")
        if model and isinstance(model, str) and model.startswith("fireworks_ai"):
            data["api_key"] = api_key

    # If metadata IDs are provided, add them as tags
    insertion_id = None
    if rollout_id is not None:
        # insertion_id is a fresh UUID v7 (time-ordered), uniquely identifying
        # this single completion within the rollout.
        insertion_id = str(uuid7())

        if "metadata" not in data:
            data["metadata"] = {}
        if "tags" not in data["metadata"]:
            data["metadata"]["tags"] = []

        # Add extracted IDs as tags
        data["metadata"]["tags"].extend(
            [
                f"rollout_id:{rollout_id}",
                f"insertion_id:{insertion_id}",
                f"invocation_id:{invocation_id}",
                f"experiment_id:{experiment_id}",
                f"run_id:{run_id}",
                f"row_id:{row_id}",
            ]
        )

        # Add Langfuse configuration
        # NOTE(review): langfuse_keys is indexed without a membership check --
        # an unknown project_id raises KeyError here (an unhandled 500),
        # whereas the traces endpoints return a 404 for the same condition.
        # Confirm whether validation should be added for consistency.
        data["langfuse_public_key"] = config.langfuse_keys[project_id]["public_key"]
        data["langfuse_secret_key"] = config.langfuse_keys[project_id]["secret_key"]
        data["langfuse_host"] = config.langfuse_host

    # Forward to LiteLLM's standard /chat/completions endpoint
    # Set longer timeout for LLM API calls (LLMs can be slow)
    timeout = httpx.Timeout(config.request_timeout)
    async with httpx.AsyncClient(timeout=timeout) as client:
        # Copy headers from original request but exclude content-length (httpx will set it correctly)
        headers = dict(request.headers)
        headers.pop("host", None)
        headers.pop("content-length", None)  # Let httpx calculate the correct length
        headers["content-type"] = "application/json"

        # Forward to LiteLLM
        litellm_url = f"{config.litellm_url}/chat/completions"

        response = await client.post(
            litellm_url,
            json=data,  # httpx will serialize and set correct Content-Length
            headers=headers,
        )

        # Register insertion_id in Redis only on successful response
        if response.status_code == 200 and insertion_id is not None and rollout_id is not None:
            register_insertion_id(redis_client, rollout_id, insertion_id)

        # Return the response
        return Response(
            content=response.content,
            status_code=response.status_code,
            headers=dict(response.headers),
        )


async def proxy_to_litellm(config: ProxyConfig, path: str, request: Request) -> Response:
    """
    Catch-all proxy: Forward any request to LiteLLM, extracting API key from Authorization header.
    """
    # Set longer timeout for LLM API calls (LLMs can be slow)
    timeout = httpx.Timeout(config.request_timeout)
    async with httpx.AsyncClient(timeout=timeout) as client:
        # Copy headers
        headers = dict(request.headers)
        headers.pop("host", None)
        headers.pop("content-length", None)

        # Get body
        body = await request.body()

        # Pass through API key from Authorization header
        if request.method in ["POST", "PUT", "PATCH"] and body:
            try:
                data = json.loads(body)

                auth_header = request.headers.get("authorization", "")
                if auth_header.startswith("Bearer "):
                    api_key = auth_header.replace("Bearer ", "").strip()
                    data["api_key"] = api_key

                # Re-serialize
                body = json.dumps(data).encode()
            except json.JSONDecodeError:
                # Non-JSON bodies are forwarded untouched.
                pass

        # Forward to LiteLLM
        litellm_url = f"{config.litellm_url}/{path}"

        response = await client.request(
            method=request.method,
            url=litellm_url,
            headers=headers,
            content=body,
        )

        return Response(
            content=response.content,
            status_code=response.status_code,
            headers=dict(response.headers),
        )

diff --git a/eval_protocol/proxy/proxy_core/main.py b/eval_protocol/proxy/proxy_core/main.py
new file mode 100644
index 00000000..93b6dcb4
--- /dev/null
+++ b/eval_protocol/proxy/proxy_core/main.py
@@ -0,0 +1,10 @@
import os
from .app import create_app

# Entry point for `python -m proxy_core.main` (see Dockerfile.gateway CMD):
# running as a module still sets __name__ == "__main__".
if __name__ == "__main__":
    import uvicorn

    # Build app with default NoAuth for local runs
    application = create_app()
    port = int(os.getenv("PORT", "4000"))
    uvicorn.run(application, host="0.0.0.0", port=port)

diff --git a/eval_protocol/proxy/proxy_core/models.py b/eval_protocol/proxy/proxy_core/models.py
new file mode 100644
index 00000000..f3b5e614
--- /dev/null
+++ b/eval_protocol/proxy/proxy_core/models.py
@@ -0,0 +1,98 @@
"""
Models for the LiteLLM Metadata Proxy.
+""" + +from pydantic import BaseModel +from typing import Optional, List, Any, Dict, TypeAlias, Callable +from fastapi import Request + + +ChatRequestHook: TypeAlias = Callable[[Dict[str, Any], Request, "ChatParams"], tuple[Dict[str, Any], "ChatParams"]] +TracesRequestHook: TypeAlias = Callable[[Request, "TracesParams"], "TracesParams"] + + +class AccountInfo(BaseModel): + """Account information returned from authentication.""" + + account_id: str + + +class ChatParams(BaseModel): + """Typed container for chat completion URL path parameters.""" + + project_id: Optional[str] = None + rollout_id: Optional[str] = None + invocation_id: Optional[str] = None + experiment_id: Optional[str] = None + run_id: Optional[str] = None + row_id: Optional[str] = None + encoded_base_url: Optional[str] = None + + +class TracesParams(BaseModel): + """Typed container for traces query parameters and controls.""" + + tags: Optional[List[str]] = None + project_id: Optional[str] = None + limit: int = 100 + sample_size: Optional[int] = None + user_id: Optional[str] = None + session_id: Optional[str] = None + name: Optional[str] = None + environment: Optional[str] = None + version: Optional[str] = None + release: Optional[str] = None + fields: Optional[str] = None + hours_back: Optional[int] = None + from_timestamp: Optional[str] = None + to_timestamp: Optional[str] = None + sleep_between_gets: float = 2.5 + max_retries: int = 3 + + +class ProxyConfig(BaseModel): + """Configuration model for the LiteLLM Metadata Proxy""" + + litellm_url: str + request_timeout: float = 300.0 + langfuse_host: str + langfuse_keys: Dict[str, Dict[str, str]] + default_project_id: str + preprocess_chat_request: Optional[ChatRequestHook] = None + preprocess_traces_request: Optional[TracesRequestHook] = None + + +class ObservationResponse(BaseModel): + """Response model for a single observation within a trace""" + + id: str + type: Optional[str] = None + name: Optional[str] = None + start_time: Optional[str] = None 
+ end_time: Optional[str] = None + input: Optional[Any] = None + output: Optional[Any] = None + parent_observation_id: Optional[str] = None + + +class TraceResponse(BaseModel): + """Response model for a single trace""" + + id: str + name: Optional[str] = None + user_id: Optional[str] = None + session_id: Optional[str] = None + tags: List[str] = [] + timestamp: Optional[str] = None + input: Optional[Any] = None + output: Optional[Any] = None + metadata: Optional[Any] = None + observations: List[ObservationResponse] = [] + + +class LangfuseTracesResponse(BaseModel): + """Response model for the /traces endpoint""" + + project_id: str + total_traces: int + traces: List[TraceResponse] diff --git a/eval_protocol/proxy/proxy_core/redis_utils.py b/eval_protocol/proxy/proxy_core/redis_utils.py new file mode 100644 index 00000000..fa24c38c --- /dev/null +++ b/eval_protocol/proxy/proxy_core/redis_utils.py @@ -0,0 +1,48 @@ +""" +Redis utilities for tracking chat completions via insertion IDs. +""" + +import logging +from typing import Set +import redis + +logger = logging.getLogger(__name__) + + +def register_insertion_id(redis_client: redis.Redis, rollout_id: str, insertion_id: str) -> bool: + """Register an insertion_id for a rollout_id in Redis. + + Tracks all expected completion insertion_ids for this rollout. + + Args: + rollout_id: The rollout ID + insertion_id: Unique identifier for this specific completion + + Returns: + True if successful, False otherwise + """ + try: + redis_client.sadd(rollout_id, insertion_id) + logger.info(f"Registered insertion_id {insertion_id} for rollout {rollout_id}") + return True + except Exception as e: + logger.error(f"Failed to register insertion_id for {rollout_id}: {e}") + return False + + +def get_insertion_ids(redis_client: redis.Redis, rollout_id: str) -> Set[str]: + """Get all expected insertion_ids for a rollout_id from Redis. 
+ + Args: + rollout_id: The rollout ID to get insertion_ids for + + Returns: + Set of insertion_id strings, empty set if none found or on error + """ + try: + insertion_ids = redis_client.smembers(rollout_id) + logger.debug(f"Found {len(insertion_ids)} expected insertion_ids for rollout {rollout_id}") + return insertion_ids + except Exception as e: + logger.error(f"Failed to get insertion_ids for {rollout_id}: {e}") + return set() diff --git a/eval_protocol/proxy/proxy_core/secrets.yaml.example b/eval_protocol/proxy/proxy_core/secrets.yaml.example new file mode 100644 index 00000000..e010d657 --- /dev/null +++ b/eval_protocol/proxy/proxy_core/secrets.yaml.example @@ -0,0 +1,14 @@ +langfuse_keys: + project_1_id: + public_key: pk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + secret_key: sk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + project_2_id: + public_key: pk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + secret_key: sk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + project_3_id: + public_key: pk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + secret_key: sk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + project_4_id: + public_key: pk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + secret_key: sk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx +default_project_id: project_1_id diff --git a/eval_protocol/proxy/requirements.txt b/eval_protocol/proxy/requirements.txt new file mode 100644 index 00000000..15d21d0b --- /dev/null +++ b/eval_protocol/proxy/requirements.txt @@ -0,0 +1,7 @@ +fastapi>=0.116.1 +uvicorn>=0.24.0 +httpx>=0.25.0 +redis>=5.0.0 +langfuse>=2.0.0 +uuid6>=2025.0.0 +PyYAML>=6.0.0 diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index 3050b1f5..d27ace02 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -43,7 +43,7 @@ def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: base_url = config.model_base_url or "https://tracing.fireworks.ai" adapter = 
FireworksTracingAdapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7) def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: