From 0388ce26c9c73ff742e5ddba62d2fc47f8c9b877 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Thu, 9 Oct 2025 01:16:19 -0700 Subject: [PATCH 01/15] Open source the proxy --- .gitignore | 3 + eval_protocol/adapters/fireworks_tracing.py | 6 +- eval_protocol/proxy/Dockerfile.gateway | 26 ++ eval_protocol/proxy/config_no_cache.yaml | 10 + eval_protocol/proxy/docker-compose.yml | 65 ++++ eval_protocol/proxy/proxy_core/__init__.py | 10 + eval_protocol/proxy/proxy_core/app.py | 259 +++++++++++++ eval_protocol/proxy/proxy_core/auth.py | 12 + eval_protocol/proxy/proxy_core/langfuse.py | 358 ++++++++++++++++++ eval_protocol/proxy/proxy_core/litellm.py | 168 ++++++++ eval_protocol/proxy/proxy_core/main.py | 10 + eval_protocol/proxy/proxy_core/models.py | 51 +++ eval_protocol/proxy/proxy_core/redis_utils.py | 48 +++ .../proxy/proxy_core/secrets.json.example | 13 + eval_protocol/proxy/requirements.txt | 6 + 15 files changed, 1043 insertions(+), 2 deletions(-) create mode 100644 eval_protocol/proxy/Dockerfile.gateway create mode 100644 eval_protocol/proxy/config_no_cache.yaml create mode 100644 eval_protocol/proxy/docker-compose.yml create mode 100644 eval_protocol/proxy/proxy_core/__init__.py create mode 100644 eval_protocol/proxy/proxy_core/app.py create mode 100644 eval_protocol/proxy/proxy_core/auth.py create mode 100644 eval_protocol/proxy/proxy_core/langfuse.py create mode 100644 eval_protocol/proxy/proxy_core/litellm.py create mode 100644 eval_protocol/proxy/proxy_core/main.py create mode 100644 eval_protocol/proxy/proxy_core/models.py create mode 100644 eval_protocol/proxy/proxy_core/redis_utils.py create mode 100644 eval_protocol/proxy/proxy_core/secrets.json.example create mode 100644 eval_protocol/proxy/requirements.txt diff --git a/.gitignore b/.gitignore index 42d2b831..d135d29f 100644 --- a/.gitignore +++ b/.gitignore @@ -105,6 +105,9 @@ env.bak/ venv.bak/ *.backup +# Secrets +secrets.json + # Spyder project settings .spyderproject .spyproject diff --git 
a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index af0bf30c..816718fe 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -7,9 +7,9 @@ from __future__ import annotations import logging import requests -import time from datetime import datetime from typing import Any, Dict, List, Optional, Protocol +import os from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message from .base import BaseAdapter @@ -349,9 +349,11 @@ def get_evaluation_rows( else: url = f"{self.base_url}/v1/traces" + headers = {"Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}"} + result = None try: - response = requests.get(url, params=params, timeout=self.timeout) + response = requests.get(url, params=params, timeout=self.timeout, headers=headers) response.raise_for_status() result = response.json() except requests.exceptions.HTTPError as e: diff --git a/eval_protocol/proxy/Dockerfile.gateway b/eval_protocol/proxy/Dockerfile.gateway new file mode 100644 index 00000000..c791a8f5 --- /dev/null +++ b/eval_protocol/proxy/Dockerfile.gateway @@ -0,0 +1,26 @@ +# Metadata Extraction Gateway - Sits in front of LiteLLM +FROM python:3.11-slim + +WORKDIR /app + +# Prevent Python from buffering stdout/stderr +ENV PYTHONUNBUFFERED=1 + +# Set secrets path to proxy directory +ENV SECRETS_PATH=/app/proxy_core/secrets.json + +# Copy requirements file +COPY ./requirements.txt /app/requirements.txt + +# Install dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy the proxy package +COPY ./proxy_core /app/proxy_core + +# Expose port +EXPOSE 4000 + +# Run the gateway as a module +# LITELLM_URL will be set by environment (docker-compose or Cloud Run) +CMD ["python", "-m", "proxy_core.main"] diff --git a/eval_protocol/proxy/config_no_cache.yaml b/eval_protocol/proxy/config_no_cache.yaml new file mode 100644 index 00000000..7adb5a72 --- /dev/null +++ 
b/eval_protocol/proxy/config_no_cache.yaml @@ -0,0 +1,10 @@ +model_list: + - model_name: "*" + litellm_params: + model: "*" +litellm_settings: + success_callback: ["langfuse"] + failure_callback: ["langfuse"] + drop_params: True +general_settings: + allow_client_side_credentials: true diff --git a/eval_protocol/proxy/docker-compose.yml b/eval_protocol/proxy/docker-compose.yml new file mode 100644 index 00000000..64eec602 --- /dev/null +++ b/eval_protocol/proxy/docker-compose.yml @@ -0,0 +1,65 @@ +services: + # Redis - For tracking assistant message counts + redis: + image: redis:7-alpine + platform: linux/amd64 + container_name: proxy-redis + ports: + - "6379:6379" # Expose for debugging if needed + networks: + - litellm-network + restart: unless-stopped + command: redis-server --appendonly yes + volumes: + - redis-data:/data + + # LiteLLM Backend - Handles actual LLM proxying + litellm-backend: + image: litellm/litellm:v1.77.3-stable + platform: linux/amd64 + container_name: litellm-backend + command: ["--config", "/app/config.yaml", "--port", "4000", "--host", "0.0.0.0"] + environment: + - LANGFUSE_PUBLIC_KEY=dummy # Set dummy public and private key so Langfuse instance initializes in LiteLLM, then real keys get sent in proxy + - LANGFUSE_SECRET_KEY=dummy + - LANGFUSE_HOST=https://langfuse.fireworks.ai + volumes: + - ./config_no_cache.yaml:/app/config.yaml:ro + ports: + - "4001:4000" # Expose on 4001 for direct access if needed + networks: + - litellm-network + restart: unless-stopped + + # Metadata Gateway - Public-facing service that extracts metadata from URLs + metadata-gateway: + build: + context: . 
+ dockerfile: Dockerfile.gateway + container_name: metadata-gateway + environment: + # Point to the LiteLLM backend service + - LITELLM_URL=http://litellm-backend:4000 + - PORT=4000 + # Redis configuration for assistant message counting + - REDIS_HOST=redis + - REDIS_PORT=6379 + # No password for local Redis + - REQUEST_TIMEOUT=300 + # Logging level: INFO (default) + - LOG_LEVEL=INFO + ports: + - "4000:4000" # Main public-facing port + networks: + - litellm-network + depends_on: + - litellm-backend + - redis + restart: unless-stopped + +networks: + litellm-network: + driver: bridge + +volumes: + redis-data: diff --git a/eval_protocol/proxy/proxy_core/__init__.py b/eval_protocol/proxy/proxy_core/__init__.py new file mode 100644 index 00000000..eb6975fa --- /dev/null +++ b/eval_protocol/proxy/proxy_core/__init__.py @@ -0,0 +1,10 @@ +from .models import ProxyConfig +from .app import create_app +from .auth import AuthProvider, NoAuthProvider + +__all__ = [ + "ProxyConfig", + "create_app", + "AuthProvider", + "NoAuthProvider", +] diff --git a/eval_protocol/proxy/proxy_core/app.py b/eval_protocol/proxy/proxy_core/app.py new file mode 100644 index 00000000..d0fe3059 --- /dev/null +++ b/eval_protocol/proxy/proxy_core/app.py @@ -0,0 +1,259 @@ +""" +Metadata Extraction Gateway +A FastAPI service that sits in front of LiteLLM and extracts metadata from URL paths. 
def build_proxy_config() -> ProxyConfig:
    """Load environment and secrets, and build ProxyConfig (no Redis).

    Environment variables read:
        LITELLM_URL: Base URL of the LiteLLM backend (required).
        REQUEST_TIMEOUT: Upstream request timeout in seconds (default 300.0).
        SECRETS_PATH: Location of secrets.json; when unset, the file is
            expected next to this module.

    Returns:
        A ProxyConfig built from the environment and secrets.json.

    Raises:
        ValueError: If LITELLM_URL is unset, REQUEST_TIMEOUT is not a number,
            or secrets.json is missing, is not valid JSON, lacks a required
            key, or names a default project that has no Langfuse keys.
    """
    # Env
    litellm_url = os.getenv("LITELLM_URL")
    if not litellm_url:
        raise ValueError("LITELLM_URL environment variable must be set")

    raw_timeout = os.getenv("REQUEST_TIMEOUT", "300.0")
    try:
        request_timeout = float(raw_timeout)
    except ValueError as e:
        # Name the offending variable; float()'s own message does not.
        raise ValueError(f"REQUEST_TIMEOUT must be a number of seconds, got {raw_timeout!r}") from e

    # Secrets - use SECRETS_PATH env var if set, otherwise the file next to this module
    secrets_path_str = os.getenv("SECRETS_PATH")
    if secrets_path_str:
        secrets_path = Path(secrets_path_str)
    else:
        secrets_path = Path(__file__).parent / "secrets.json"

    if not secrets_path.is_file():
        # Point at the example that ships beside this module and at the path
        # that was actually checked, so the hint can be copy-pasted.
        example_path = Path(__file__).parent / "secrets.json.example"
        raise ValueError(
            f"secrets.json not found at {secrets_path}! Please create it from secrets.json.example:\n"
            f"  cp {example_path} {secrets_path}\n"
            "Then add your Langfuse API keys to secrets.json"
        )

    try:
        with open(secrets_path, "r", encoding="utf-8") as f:
            secrets_config = json.load(f)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in secrets.json: {e}") from e

    if not isinstance(secrets_config, dict):
        raise ValueError("secrets.json must contain a JSON object at the top level")

    try:
        langfuse_keys = secrets_config["langfuse_keys"]
        default_project_id = secrets_config["default_project_id"]
    except KeyError as e:
        raise ValueError(f"Missing required key in secrets.json: {e}") from e

    if not isinstance(langfuse_keys, dict) or not isinstance(default_project_id, str):
        raise ValueError("secrets.json: 'langfuse_keys' must be an object and 'default_project_id' must be a string")

    # Fail at startup rather than with a KeyError on the first request that
    # falls back to the default project.
    if default_project_id not in langfuse_keys:
        raise ValueError(
            f"secrets.json: default_project_id '{default_project_id}' has no entry in 'langfuse_keys'"
        )

    logger.info(f"Loaded {len(langfuse_keys)} Langfuse project(s) from secrets.json")

    return ProxyConfig(
        litellm_url=litellm_url,
        request_timeout=request_timeout,
        langfuse_keys=langfuse_keys,
        default_project_id=default_project_id,
    )
def get_config(request: Request) -> ProxyConfig: + return request.app.state.config + + def get_redis(request: Request) -> redis.Redis: + return request.app.state.redis + + async def require_auth(request: Request) -> None: + auth_header = request.headers.get("authorization", "") + api_key = None + if auth_header.startswith("Bearer "): + api_key = auth_header.replace("Bearer ", "").strip() + + auth_provider.validate(api_key) + return None + + # ===================== + # Chat completion routes + # ===================== + @app.post( + "/project_id/{project_id}/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/chat/completions" + ) + @app.post( + "/v1/project_id/{project_id}/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/chat/completions" + ) + @app.post( + "/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/chat/completions" + ) + @app.post( + "/v1/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/chat/completions" + ) + @app.post( + "/project_id/{project_id}/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/encoded_base_url/{encoded_base_url}/chat/completions" + ) + @app.post( + "/v1/project_id/{project_id}/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/encoded_base_url/{encoded_base_url}/chat/completions" + ) + @app.post( + "/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/encoded_base_url/{encoded_base_url}/chat/completions" + ) + @app.post( + 
"/v1/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/encoded_base_url/{encoded_base_url}/chat/completions" + ) + async def chat_completion_with_full_metadata( + rollout_id: str, + invocation_id: str, + experiment_id: str, + run_id: str, + row_id: str, + request: Request, + project_id: Optional[str] = None, + encoded_base_url: Optional[str] = None, + config: ProxyConfig = Depends(get_config), + redis_client: redis.Redis = Depends(get_redis), + ): + return await handle_chat_completion( + config=config, + redis_client=redis_client, + request=request, + project_id=project_id, + rollout_id=rollout_id, + invocation_id=invocation_id, + experiment_id=experiment_id, + run_id=run_id, + row_id=row_id, + encoded_base_url=encoded_base_url, + ) + + @app.post("/project_id/{project_id}/chat/completions") + @app.post("/v1/project_id/{project_id}/chat/completions") + async def chat_completion_with_project_only( + project_id: str, + request: Request, + config: ProxyConfig = Depends(get_config), + redis_client: redis.Redis = Depends(get_redis), + ): + return await handle_chat_completion( + config=config, + redis_client=redis_client, + request=request, + project_id=project_id, + ) + + # =============== + # Traces routes + # =============== + @app.get("/traces", response_model=LangfuseTracesResponse) + @app.get("/v1/traces", response_model=LangfuseTracesResponse) + @app.get("/project_id/{project_id}/traces", response_model=LangfuseTracesResponse) + @app.get("/v1/project_id/{project_id}/traces", response_model=LangfuseTracesResponse) + async def get_langfuse_traces( + tags: List[str] = Query(...), # REQUIRED query param + project_id: Optional[str] = None, + limit: int = 100, + sample_size: Optional[int] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + name: Optional[str] = None, + environment: Optional[str] = None, + version: Optional[str] = None, + release: Optional[str] = None, + fields: 
Optional[str] = None, + hours_back: Optional[int] = None, + from_timestamp: Optional[str] = None, + to_timestamp: Optional[str] = None, + sleep_between_gets: float = 2.5, + max_retries: int = 3, + config: ProxyConfig = Depends(get_config), + redis_client: redis.Redis = Depends(get_redis), + _: None = Depends(require_auth), + ) -> LangfuseTracesResponse: + return await fetch_langfuse_traces( + config=config, + redis_client=redis_client, + tags=tags, + project_id=project_id, + limit=limit, + sample_size=sample_size, + user_id=user_id, + session_id=session_id, + name=name, + environment=environment, + version=version, + release=release, + fields=fields, + hours_back=hours_back, + from_timestamp=from_timestamp, + to_timestamp=to_timestamp, + sleep_between_gets=sleep_between_gets, + max_retries=max_retries, + ) + + # Health + @app.get("/health") + async def health(): + return {"status": "healthy", "service": "metadata-proxy"} + + # Catch-all + @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) + async def catch_all_proxy( + path: str, + request: Request, + config: ProxyConfig = Depends(get_config), + ): + return await proxy_to_litellm(config, path, request) + + return app diff --git a/eval_protocol/proxy/proxy_core/auth.py b/eval_protocol/proxy/proxy_core/auth.py new file mode 100644 index 00000000..8e163512 --- /dev/null +++ b/eval_protocol/proxy/proxy_core/auth.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod +from typing import Optional + + +class AuthProvider(ABC): + @abstractmethod + def validate(self, api_key: Optional[str]) -> Optional[str]: ... 
+ + +class NoAuthProvider(AuthProvider): + def validate(self, api_key: Optional[str]) -> Optional[str]: + return None diff --git a/eval_protocol/proxy/proxy_core/langfuse.py b/eval_protocol/proxy/proxy_core/langfuse.py new file mode 100644 index 00000000..b93fdb18 --- /dev/null +++ b/eval_protocol/proxy/proxy_core/langfuse.py @@ -0,0 +1,358 @@ +""" +Traces fetching handler for Langfuse integration. +""" + +import time +import random +import logging +import asyncio +from typing import List, Optional, Dict, Any, Set +from datetime import datetime, timedelta +from fastapi import HTTPException +import redis +from .redis_utils import get_insertion_ids +from .models import ProxyConfig, LangfuseTracesResponse, TraceResponse + +logger = logging.getLogger(__name__) + + +def _extract_tag_value(tags: List[str], prefix: str) -> Optional[str]: + """Extract value from a tag with the given prefix (e.g., 'rollout_id:' or 'insertion_id:').""" + for tag in tags: + if tag.startswith(prefix): + return tag.split(":", 1)[1] + return None + + +def _serialize_trace_to_dict(trace_full: Any) -> Dict[str, Any]: + """Convert Langfuse trace object to dict format.""" + timestamp = getattr(trace_full, "timestamp", None) + + return { + "id": trace_full.id, + "name": getattr(trace_full, "name", None), + "user_id": getattr(trace_full, "user_id", None), + "session_id": getattr(trace_full, "session_id", None), + "tags": getattr(trace_full, "tags", []), + "timestamp": str(timestamp) if timestamp else None, + "input": getattr(trace_full, "input", None), + "output": getattr(trace_full, "output", None), + "metadata": getattr(trace_full, "metadata", None), + "observations": [ + { + "id": obs.id, + "type": getattr(obs, "type", None), + "name": getattr(obs, "name", None), + "start_time": str(getattr(obs, "start_time", None)) if getattr(obs, "start_time", None) else None, + "end_time": str(getattr(obs, "end_time", None)) if getattr(obs, "end_time", None) else None, + "input": getattr(obs, "input", None), + 
"output": getattr(obs, "output", None), + "parent_observation_id": getattr(obs, "parent_observation_id", None), + } + for obs in getattr(trace_full, "observations", []) + ] + if hasattr(trace_full, "observations") + else [], + } + + +async def _fetch_trace_list_with_retry( + langfuse_client: Any, + page: int, + limit: int, + tags: List[str], + user_id: Optional[str], + session_id: Optional[str], + name: Optional[str], + environment: Optional[str], + version: Optional[str], + release: Optional[str], + fields: Optional[str], + from_ts: Optional[datetime], + to_ts: Optional[datetime], + max_retries: int, +) -> Any: + """Fetch trace list with rate limit retry logic.""" + list_retries = 0 + while list_retries < max_retries: + try: + traces = langfuse_client.api.trace.list( + page=page, + limit=limit, + tags=tags, + user_id=user_id, + session_id=session_id, + name=name, + environment=environment, + version=version, + release=release, + fields=fields, + from_timestamp=from_ts, + to_timestamp=to_ts, + order_by="timestamp.desc", + ) + + # If no results, possible due to indexing delay--remote rollout processor just finished pushing rows to Langfuse + if traces and traces.meta and traces.meta.total_items == 0 and page == 1: + raise Exception("Empty results") + + return traces + except Exception as e: + list_retries += 1 + if list_retries < max_retries and ("429" in str(e) or "Empty results" in str(e)): + sleep_time = 2**list_retries # Exponential backoff for rate limits + logger.warning( + "Retrying trace.list in %ds (attempt %d/%d): %s", sleep_time, list_retries, max_retries, str(e) + ) + await asyncio.sleep(sleep_time) + elif list_retries == max_retries: + # Return 404 if we've retried max_retries + # TODO: write some tests around proxy exception handling + logger.error("Failed to fetch trace list after %d retries: %s", max_retries, e) + raise HTTPException( + status_code=404, detail=f"Failed to fetch traces after {max_retries} retries: {str(e)}" + ) + else: + # Catch all 
other exceptions + logger.error("Failed to fetch trace list: %s", e) + raise HTTPException(status_code=500, detail=f"Failed to fetch traces: {str(e)}") + + +async def _fetch_trace_detail_with_retry( + langfuse_client: Any, + trace_id: str, + max_retries: int, +) -> Optional[Any]: + """Fetch full trace details with rate limit retry logic.""" + detail_retries = 0 + while detail_retries < max_retries: + try: + trace_full = langfuse_client.api.trace.get(trace_id) + return trace_full + except Exception as e: + detail_retries += 1 + if "429" in str(e) and detail_retries < max_retries: + sleep_time = 2**detail_retries # Exponential backoff for rate limits + logger.warning( + "Rate limit hit on trace.get(%s), retrying in %ds (attempt %d/%d)", + trace_id, + sleep_time, + detail_retries, + max_retries, + ) + await asyncio.sleep(sleep_time) + elif "Not Found" in str(e) or "404" in str(e): + logger.debug("Trace %s not found, skipping", trace_id) + return None + else: + logger.warning("Failed to fetch trace %s after %d retries: %s", trace_id, max_retries, e) + return None + + +async def fetch_langfuse_traces( + config: ProxyConfig, + redis_client: redis.Redis, + tags: List[str], + project_id: Optional[str] = None, + limit: int = 100, + sample_size: Optional[int] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + name: Optional[str] = None, + environment: Optional[str] = None, + version: Optional[str] = None, + release: Optional[str] = None, + fields: Optional[str] = None, + hours_back: Optional[int] = None, + from_timestamp: Optional[str] = None, + to_timestamp: Optional[str] = None, + sleep_between_gets: float = 2.5, + max_retries: int = 3, +): + """ + Fetch full traces from Langfuse for the specified project. + + This endpoint uses the stored Langfuse keys for the project and polls + traces based on the provided filters. 
+ + SECURITY: + - Tags are REQUIRED and must not be empty + - At least one tag MUST be in the format 'rollout_id:*' + - This prevents accidentally fetching all traces or traces from other clients + + If project_id is not provided, uses the default project. + + Returns a list of full trace objects (including observations) in JSON format. + """ + # Validate tags + if not tags or not any(tag.startswith("rollout_id:") for tag in tags): + raise HTTPException(status_code=422, detail="Tags must include at least one 'rollout_id:*' tag") + + # Use default project if not specified + if project_id is None: + project_id = config.default_project_id + + # Validate project_id + if project_id not in config.langfuse_keys: + raise HTTPException( + status_code=404, + detail=f"Project ID '{project_id}' not found. Available projects: {list(config.langfuse_keys.keys())}", + ) + + # Extract rollout_id from tags for Redis lookup + rollout_id = _extract_tag_value(tags, "rollout_id:") + + try: + # Import the Langfuse adapter + from langfuse import Langfuse + + # Create Langfuse client with the project's keys + langfuse_client = Langfuse( + public_key=config.langfuse_keys[project_id]["public_key"], + secret_key=config.langfuse_keys[project_id]["secret_key"], + host="https://langfuse.fireworks.ai", + ) + + # Parse datetime strings if provided + from_ts = None + to_ts = None + if from_timestamp: + from_ts = datetime.fromisoformat(from_timestamp.replace("Z", "+00:00")) + if to_timestamp: + to_ts = datetime.fromisoformat(to_timestamp.replace("Z", "+00:00")) + + # Determine time window: explicit from/to takes precedence over hours_back + if from_ts is None and to_ts is None and hours_back: + to_ts = datetime.now() + from_ts = to_ts - timedelta(hours=hours_back) + + # Get expected insertion_ids from Redis for completeness checking + expected_ids: Set[str] = set() + if rollout_id: + expected_ids = get_insertion_ids(redis_client, rollout_id) + if not expected_ids: + logger.warning( + f"No expected 
insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty traces." + ) + raise HTTPException( + status_code=500, + detail=f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty traces.", + ) + + # Track all traces we've collected across retry attempts + trace_ids: Set[str] = set() # Langfuse trace IDs (for deduplication) + all_traces: List[Dict[str, Any]] = [] # Full trace data + insertion_ids: Set[str] = set() # Insertion IDs extracted from traces (for completeness check) + + for retry in range(max_retries): + # On first attempt, use rollout_id tag. On retries, target missing insertion_ids + if retry == 0: + fetch_tags = tags + else: + # Build targeted tags for missing insertion_ids + missing_ids = expected_ids - insertion_ids + fetch_tags = [f"insertion_id:{id}" for id in missing_ids] + logger.info(f"Retry {retry}: Targeting {len(fetch_tags)} missing insertion_ids") + + current_page = 1 + collected = 0 + + while collected < limit: + current_page_limit = min(100, limit - collected) # Langfuse API max is 100 + + # Fetch trace list with rate limit retry logic + traces = await _fetch_trace_list_with_retry( + langfuse_client, + current_page, + current_page_limit, + fetch_tags, + user_id, + session_id, + name, + environment, + version, + release, + fields, + from_ts, + to_ts, + max_retries, + ) + + if not traces or not traces.data: + logger.debug("No more traces found on page %d", current_page) + break + + # For traces we find not in our current list of traces, do trace.get + for trace_info in traces.data: + if trace_info.id in trace_ids: + continue # Skip already processed traces + + if sleep_between_gets > 0: + await asyncio.sleep(sleep_between_gets) + + # Fetch full trace with rate limit retry logic + trace_full = await _fetch_trace_detail_with_retry( + langfuse_client, + trace_info.id, + max_retries, + ) + + if trace_full: + try: + trace_dict = _serialize_trace_to_dict(trace_full) + all_traces.append(trace_dict) + 
trace_ids.add(trace_info.id) + + # Extract insertion_id for completeness checking + insertion_id = _extract_tag_value(trace_dict.get("tags", []), "insertion_id:") + if insertion_id: + insertion_ids.add(insertion_id) + + except Exception as e: + logger.warning("Failed to serialize trace %s: %s", trace_info.id, e) + continue + + collected += len(traces.data) + + # Check if we have more pages + if hasattr(traces.meta, "page") and hasattr(traces.meta, "total_pages"): + if traces.meta.page >= traces.meta.total_pages: + break + elif len(traces.data) < current_page_limit: + break + + current_page += 1 + + # If we have all expected completions or more, return traces. At least once is ok. + if expected_ids <= insertion_ids: + if sample_size is not None and len(all_traces) > sample_size: + all_traces = random.sample(all_traces, sample_size) + + return LangfuseTracesResponse( + project_id=project_id, + total_traces=len(all_traces), + traces=[TraceResponse(**trace) for trace in all_traces], + ) + + # If it doesn't match, wait and do loop again (exponential backoff) + if retry < max_retries - 1: + wait_time = 2**retry + logger.info( + f"Attempt {retry + 1}/{max_retries}. Found {len(insertion_ids)}/{len(expected_ids)} expected. Waiting {wait_time}s..." + ) + await asyncio.sleep(wait_time) + + logger.error( + f"Incomplete traces for rollout_id '{rollout_id}': Found {len(insertion_ids)}/{len(expected_ids)} completions." + ) + raise HTTPException( + status_code=404, + detail=f"Incomplete traces for rollout_id '{rollout_id}': Found {len(insertion_ids)}/{len(expected_ids)} completions.", + ) + + except ImportError: + raise HTTPException(status_code=500, detail="Langfuse SDK not installed. 
# Upstream response headers that must not be copied onto the proxied response:
# hop-by-hop headers, plus content-encoding/content-length, which describe the
# upstream wire format rather than the `response.content` bytes that httpx hands back.
_UNFORWARDED_RESPONSE_HEADERS = frozenset(
    {
        "connection",
        "content-encoding",
        "content-length",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


def _bearer_token(auth_header: str) -> Optional[str]:
    """Return the token of a 'Bearer <token>' header value, or None if it has another form."""
    prefix = "Bearer "
    if not auth_header.startswith(prefix):
        return None
    # Strip only the leading scheme; str.replace would also alter the token itself.
    return auth_header[len(prefix) :].strip()


def _forwardable_response_headers(upstream_headers) -> dict:
    """Copy upstream response headers, dropping those listed in _UNFORWARDED_RESPONSE_HEADERS."""
    return {
        name: value for name, value in upstream_headers.items() if name.lower() not in _UNFORWARDED_RESPONSE_HEADERS
    }


async def handle_chat_completion(
    config: ProxyConfig,
    redis_client: redis.Redis,
    request: Request,
    project_id: Optional[str] = None,
    rollout_id: Optional[str] = None,
    invocation_id: Optional[str] = None,
    experiment_id: Optional[str] = None,
    run_id: Optional[str] = None,
    row_id: Optional[str] = None,
    encoded_base_url: Optional[str] = None,
) -> Response:
    """
    Handle chat completion requests and forward to LiteLLM.

    If rollout_id is provided, the metadata IDs are appended to
    ``metadata.tags`` together with a freshly generated insertion_id, and that
    insertion_id is registered in Redis once LiteLLM answers with HTTP 200.

    If encoded_base_url is provided, it is decoded from URL-safe base64
    (padding optional) and sent to LiteLLM as ``base_url``.

    The bearer token of the incoming Authorization header, if any, is copied
    into the body as ``api_key``, and the Langfuse keys of the selected
    project are added to the body.

    Args:
        config: Proxy configuration (LiteLLM URL, timeout, Langfuse keys).
        redis_client: Redis client used to register the insertion_id.
        request: Incoming request; its JSON body is the completion payload.
        project_id: Langfuse project; defaults to config.default_project_id.
        rollout_id: Rollout identifier; enables tagging when not None.
        invocation_id: Added as a tag when rollout_id is given.
        experiment_id: Added as a tag when rollout_id is given.
        run_id: Added as a tag when rollout_id is given.
        row_id: Added as a tag when rollout_id is given.
        encoded_base_url: URL-safe base64 of the provider base URL.

    Returns:
        The LiteLLM response body and status code, with the upstream headers
        minus hop-by-hop and content-encoding/length headers.

    Raises:
        HTTPException: 404 if the project has no Langfuse keys; 400 if the
            body is not a JSON object, encoded_base_url cannot be decoded, or
            ``metadata``/``metadata.tags`` have the wrong type.
    """
    # Use default project if not specified
    if project_id is None:
        project_id = config.default_project_id

    # Same status as the traces route for an unknown project, instead of a KeyError (500).
    project_keys = config.langfuse_keys.get(project_id)
    if project_keys is None:
        raise HTTPException(status_code=404, detail=f"Project ID '{project_id}' not found")

    # Read the original request body
    body = await request.body()
    try:
        data = json.loads(body) if body else {}
    except ValueError as e:  # JSONDecodeError and UnicodeDecodeError are both ValueErrors
        raise HTTPException(status_code=400, detail=f"Invalid JSON request body: {str(e)}") from e
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    # Decode and add base_url if provided
    if encoded_base_url:
        try:
            # Re-add '=' padding in case the value was embedded in the path without it.
            padded = encoded_base_url + "=" * (-len(encoded_base_url) % 4)
            base_url = base64.urlsafe_b64decode(padded).decode("utf-8")
        except ValueError as e:  # binascii.Error and UnicodeDecodeError are both ValueErrors
            logger.error(f"Failed to decode base_url: {e}")
            raise HTTPException(status_code=400, detail=f"Invalid encoded_base_url: {str(e)}") from e
        data["base_url"] = base_url
        logger.debug(f"Decoded base_url: {base_url}")

    # Extract API key from Authorization header and inject into request body
    api_key = _bearer_token(request.headers.get("authorization", ""))
    if api_key is not None:
        data["api_key"] = api_key

    # If metadata IDs are provided, add them as tags
    insertion_id = None
    if rollout_id is not None:
        insertion_id = str(uuid7())

        # Treat an explicit null like a missing field.
        metadata = data.get("metadata") or {}
        if not isinstance(metadata, dict):
            raise HTTPException(status_code=400, detail="'metadata' must be a JSON object")
        tags = metadata.get("tags") or []
        if not isinstance(tags, list):
            raise HTTPException(status_code=400, detail="'metadata.tags' must be a list")

        # Add extracted IDs as tags
        tags.extend(
            [
                f"rollout_id:{rollout_id}",
                f"insertion_id:{insertion_id}",
                f"invocation_id:{invocation_id}",
                f"experiment_id:{experiment_id}",
                f"run_id:{run_id}",
                f"row_id:{row_id}",
            ]
        )
        metadata["tags"] = tags
        data["metadata"] = metadata

    # Add Langfuse configuration
    data["langfuse_public_key"] = project_keys["public_key"]
    data["langfuse_secret_key"] = project_keys["secret_key"]
    data["langfuse_host"] = "https://langfuse.fireworks.ai"

    # Forward to LiteLLM's standard /chat/completions endpoint
    # Set longer timeout for LLM API calls (LLMs can be slow)
    timeout = httpx.Timeout(config.request_timeout)
    async with httpx.AsyncClient(timeout=timeout) as client:
        # Copy headers from original request but exclude content-length (httpx will set it correctly)
        headers = dict(request.headers)
        headers.pop("host", None)
        headers.pop("content-length", None)  # Let httpx calculate the correct length
        headers["content-type"] = "application/json"

        # Forward to LiteLLM
        litellm_url = f"{config.litellm_url}/chat/completions"

        response = await client.post(
            litellm_url,
            json=data,  # httpx will serialize and set correct Content-Length
            headers=headers,
        )

        # Register insertion_id in Redis only on successful response
        if response.status_code == 200 and insertion_id is not None and rollout_id is not None:
            register_insertion_id(redis_client, rollout_id, insertion_id)

        # Return the response. `response.content` is the decoded body, so the
        # upstream encoding/length headers are dropped and recomputed by Response.
        return Response(
            content=response.content,
            status_code=response.status_code,
            headers=_forwardable_response_headers(response.headers),
        )
+ + return Response( + content=response.content, + status_code=response.status_code, + headers=dict(response.headers), + ) diff --git a/eval_protocol/proxy/proxy_core/main.py b/eval_protocol/proxy/proxy_core/main.py new file mode 100644 index 00000000..93b6dcb4 --- /dev/null +++ b/eval_protocol/proxy/proxy_core/main.py @@ -0,0 +1,10 @@ +import os +from .app import create_app + +if __name__ == "__main__": + import uvicorn + + # Build app with default NoAuth for local runs + application = create_app() + port = int(os.getenv("PORT", "4000")) + uvicorn.run(application, host="0.0.0.0", port=port) diff --git a/eval_protocol/proxy/proxy_core/models.py b/eval_protocol/proxy/proxy_core/models.py new file mode 100644 index 00000000..903d77f0 --- /dev/null +++ b/eval_protocol/proxy/proxy_core/models.py @@ -0,0 +1,51 @@ +""" +Models for the LiteLLM Metadata Proxy. +""" + +from pydantic import BaseModel +from typing import Optional, List, Any, Dict + + +class ProxyConfig(BaseModel): + """Configuration model for the LiteLLM Metadata Proxy""" + + litellm_url: str + request_timeout: float = 300.0 + langfuse_keys: Dict[str, Dict[str, str]] + default_project_id: str + + +class ObservationResponse(BaseModel): + """Response model for a single observation within a trace""" + + id: str + type: Optional[str] = None + name: Optional[str] = None + start_time: Optional[str] = None + end_time: Optional[str] = None + input: Optional[Any] = None + output: Optional[Any] = None + parent_observation_id: Optional[str] = None + + +class TraceResponse(BaseModel): + """Response model for a single trace""" + + id: str + name: Optional[str] = None + user_id: Optional[str] = None + session_id: Optional[str] = None + tags: List[str] = [] + timestamp: Optional[str] = None + input: Optional[Any] = None + output: Optional[Any] = None + metadata: Optional[Any] = None + observations: List[ObservationResponse] = [] + + +class LangfuseTracesResponse(BaseModel): + """Response model for the /traces endpoint""" + 
+ project_id: str + total_traces: int + traces: List[TraceResponse] diff --git a/eval_protocol/proxy/proxy_core/redis_utils.py b/eval_protocol/proxy/proxy_core/redis_utils.py new file mode 100644 index 00000000..fa24c38c --- /dev/null +++ b/eval_protocol/proxy/proxy_core/redis_utils.py @@ -0,0 +1,48 @@ +""" +Redis utilities for tracking chat completions via insertion IDs. +""" + +import logging +from typing import Set +import redis + +logger = logging.getLogger(__name__) + + +def register_insertion_id(redis_client: redis.Redis, rollout_id: str, insertion_id: str) -> bool: + """Register an insertion_id for a rollout_id in Redis. + + Tracks all expected completion insertion_ids for this rollout. + + Args: + rollout_id: The rollout ID + insertion_id: Unique identifier for this specific completion + + Returns: + True if successful, False otherwise + """ + try: + redis_client.sadd(rollout_id, insertion_id) + logger.info(f"Registered insertion_id {insertion_id} for rollout {rollout_id}") + return True + except Exception as e: + logger.error(f"Failed to register insertion_id for {rollout_id}: {e}") + return False + + +def get_insertion_ids(redis_client: redis.Redis, rollout_id: str) -> Set[str]: + """Get all expected insertion_ids for a rollout_id from Redis. 
+ + Args: + rollout_id: The rollout ID to get insertion_ids for + + Returns: + Set of insertion_id strings, empty set if none found or on error + """ + try: + insertion_ids = redis_client.smembers(rollout_id) + logger.debug(f"Found {len(insertion_ids)} expected insertion_ids for rollout {rollout_id}") + return insertion_ids + except Exception as e: + logger.error(f"Failed to get insertion_ids for {rollout_id}: {e}") + return set() diff --git a/eval_protocol/proxy/proxy_core/secrets.json.example b/eval_protocol/proxy/proxy_core/secrets.json.example new file mode 100644 index 00000000..5dd37878 --- /dev/null +++ b/eval_protocol/proxy/proxy_core/secrets.json.example @@ -0,0 +1,13 @@ +{ + "langfuse_keys": { + "your-project-id-1": { + "public_key": "pk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "secret_key": "sk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + }, + "your-project-id-2": { + "public_key": "pk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "secret_key": "sk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + } + }, + "default_project_id": "your-project-id-1" +} diff --git a/eval_protocol/proxy/requirements.txt b/eval_protocol/proxy/requirements.txt new file mode 100644 index 00000000..7726397b --- /dev/null +++ b/eval_protocol/proxy/requirements.txt @@ -0,0 +1,6 @@ +fastapi>=0.116.1 +uvicorn>=0.24.0 +httpx>=0.25.0 +redis>=5.0.0 +langfuse>=2.0.0 +uuid6>=2025.0.0 From c2ec0c8bb3f927b3c7f77c8a0e4fb955c7685ea6 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Thu, 9 Oct 2025 01:23:30 -0700 Subject: [PATCH 02/15] readme --- eval_protocol/proxy/README.md | 392 ++++++++++++++++++++++++++++++++++ 1 file changed, 392 insertions(+) create mode 100644 eval_protocol/proxy/README.md diff --git a/eval_protocol/proxy/README.md b/eval_protocol/proxy/README.md new file mode 100644 index 00000000..2f9c62aa --- /dev/null +++ b/eval_protocol/proxy/README.md @@ -0,0 +1,392 @@ +# LiteLLM Metadata Extraction Gateway + +A FastAPI-based metadata extraction gateway that sits in front of LiteLLM to 
inject evaluation metadata into LLM requests and track completions for distributed evaluation workflows. + +## Overview + +The Metadata Gateway is a proxy service that enhances LiteLLM by: +- **Extracting metadata from URL paths** and injecting it as Langfuse tags +- **Managing Langfuse credentials** per-project without exposing them to clients +- **Tracking completion insertion IDs** in Redis for completeness verification +- **Fetching and validating traces** from Langfuse with built-in retry logic + +This enables distributed evaluation systems to track which LLM completions belong to which evaluation runs, ensuring data completeness and proper attribution. + +## Architecture + +``` +┌─────────────┐ +│ Client │ +│ (SDK/CLI) │ +└──────┬──────┘ + │ Authorization: Bearer + │ POST /rollout_id/{id}/invocation_id/{id}/.../chat/completions + ▼ +┌─────────────────────────┐ +│ Metadata Gateway │ +│ (FastAPI Service) │ +│ - Extract metadata │ +│ - Inject Langfuse keys │ +│ - Generate UUID7 IDs │ +└──────┬──────────┬───────┘ + │ │ + ▼ ▼ + ┌────────┐ ┌─────────────┐ + │ Redis │ │ LiteLLM │ + │ │ │ Backend │ + │ Track │ │ │ + │ IDs │ └──────┬──────┘ + └────────┘ │ + ▼ + ┌─────────────┐ + │ Langfuse │ + │ (Tracing) │ + └─────────────┘ +``` + +### Components + +#### 1. **Metadata Gateway** (`proxy_core/`) + - **`app.py`**: Main FastAPI application with route definitions + - **`litellm.py`**: LiteLLM client for forwarding requests + - **`langfuse.py`**: Langfuse trace fetching with retry logic + - **`redis_utils.py`**: Redis operations for insertion ID tracking + - **`models.py`**: Pydantic models for configuration and responses + - **`auth.py`**: Authentication provider interface (extensible) + - **`main.py`**: Entry point for running the service + +#### 2. **Redis** + - Stores insertion IDs per rollout for completeness checking + - Uses Redis Sets: `rollout_id -> {insertion_id_1, insertion_id_2, ...}` + +#### 3. 
**LiteLLM Backend** + - Standard LiteLLM proxy for routing to LLM providers + - Configured with Langfuse callbacks for automatic tracing + +## Key Features + +### Metadata Injection +URL paths encode evaluation metadata that gets injected as Langfuse tags: +- `rollout_id`: Unique ID for a batch evaluation run +- `invocation_id`: ID for a single invocation within a rollout +- `experiment_id`: Experiment identifier +- `run_id`: Run identifier within an experiment +- `row_id`: Dataset row identifier +- `insertion_id`: Auto-generated UUID7 for this specific completion + +### Completeness Tracking +1. **On chat completion**: Generate UUID7 insertion_id and store in Redis +2. **On trace fetch**: Verify all expected insertion_ids are present in Langfuse +3. **Retry logic**: Automatic retries with exponential backoff for incomplete traces + +### Multi-Project Support +- Store Langfuse credentials for multiple projects in `secrets.json` +- Route requests to the correct project via `project_id` in URL or use default +- Credentials never exposed to clients + +## Setup + +### Prerequisites +- Docker and Docker Compose (recommended) +- Python 3.11+ (for local development) + +### Local Development: Docker Compose + +1. **Create secrets file:** + ```bash + cp proxy_core/secrets.json.example proxy_core/secrets.json + ``` + +2. **Edit `proxy_core/secrets.json`** with your Langfuse credentials. +**Important**: where we have "my-project", you would use the ID of your Langfuse project, similar to format `cmg00asdf0123...`. + ```json + { + "langfuse_keys": { + "my-project": { + "public_key": "pk-lf-...", + "secret_key": "sk-lf-..." + } + }, + "default_project_id": "my-project" + } + ``` + +3. **Start services:** + ```bash + docker-compose up -d + ``` + +4. **Verify services are running:** + ```bash + curl http://localhost:4000/health + # Expected: {"status":"healthy","service":"metadata-proxy"} + ``` + +The gateway will be available at `http://localhost:4000`. 
+ +## API Reference + +### Chat Completions + +#### With Full Metadata +``` +POST /rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/chat/completions +POST /project_id/{project_id}/rollout_id/{rollout_id}/.../chat/completions +``` + +**Features:** +- Extracts metadata from URL path +- Generates UUID7 insertion_id +- Injects Langfuse credentials +- Tracks insertion_id in Redis +- Forwards to LiteLLM + +**Request:** +```bash +curl -X POST http://localhost:4000/rollout_id/abc123/invocation_id/inv1/experiment_id/exp1/run_id/run1/row_id/row1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-..." \ + -d '{ + "model": "fireworks_ai/accounts/fireworks/models/llama-v3p3-70b-instruct", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +**Response:** Standard OpenAI chat completion response + +#### With Project Only +``` +POST /project_id/{project_id}/chat/completions +``` + +For completions that don't need rollout tracking. + +#### With Encoded Base URL +``` +POST /rollout_id/{rollout_id}/.../encoded_base_url/{encoded_base_url}/chat/completions +``` + +The `encoded_base_url` is base64-encoded URL string injected into the request body as `base_url`. 
+ +### Trace Fetching + +#### Fetch Langfuse Traces +``` +GET /traces?tags=rollout_id:abc123 +GET /project_id/{project_id}/traces?tags=rollout_id:abc123 +``` + +**Required Query Parameters:** +- `tags`: Array of tags (must include at least one `rollout_id:*` tag) + +**Optional Query Parameters:** +- `limit`: Max traces to fetch (default: 100) +- `sample_size`: Random sample size if more traces found +- `user_id`, `session_id`, `name`, `environment`, `version`, `release`: Langfuse filters +- `fields`: Comma-separated fields to include +- `hours_back`: Fetch traces from last N hours +- `from_timestamp`, `to_timestamp`: ISO datetime strings for time range +- `sleep_between_gets`: Delay between trace.get calls (default: 2.5s) +- `max_retries`: Retry attempts for incomplete traces (default: 3) + +**Completeness Logic:** +1. Fetches traces from Langfuse matching tags +2. Extracts insertion_ids from trace tags +3. Compares with expected insertion_ids in Redis +4. Retries with exponential backoff if incomplete +5. Returns 404 if still incomplete after max_retries + +**Response:** +```json +{ + "project_id": "my-project", + "total_traces": 42, + "traces": [ + { + "id": "trace-123", + "name": "chat-completion", + "tags": ["rollout_id:abc123", "insertion_id:uuid7..."], + "input": {...}, + "output": {...}, + "observations": [...] + } + ] +} +``` + +### Health Check +``` +GET /health +``` + +Returns service health status. + +### Catch-All Proxy +``` +ANY /{path} +``` + +Forwards any other request to LiteLLM backend with API key injection. 
+ +## Configuration + +### Environment Variables + +| Variable | Required | Default | Description | +|----------|----------|---------|-------------| +| `LITELLM_URL` | Yes | - | URL of LiteLLM backend | +| `REDIS_HOST` | Yes | - | Redis hostname | +| `REDIS_PORT` | No | 6379 | Redis port | +| `REDIS_PASSWORD` | No | - | Redis password | +| `SECRETS_PATH` | No | `proxy_core/secrets.json` | Path to secrets file | +| `REQUEST_TIMEOUT` | No | 300.0 | Request timeout in seconds | +| `LOG_LEVEL` | No | INFO | Logging level | +| `PORT` | No | 4000 | Gateway port | + +### Secrets Configuration + +Create `proxy_core/secrets.json`: +```json +{ + "langfuse_keys": { + "project-1": { + "public_key": "pk-lf-...", + "secret_key": "sk-lf-..." + }, + "project-2": { + "public_key": "pk-lf-...", + "secret_key": "sk-lf-..." + } + }, + "default_project_id": "project-1" +} +``` + +**Security:** Add `secrets.json` to `.gitignore` (already configured). + +### LiteLLM Configuration + +The `config_no_cache.yaml` configures LiteLLM: +```yaml +model_list: + - model_name: "*" + litellm_params: + model: "*" +litellm_settings: + success_callback: ["langfuse"] + failure_callback: ["langfuse"] + drop_params: True +general_settings: + allow_client_side_credentials: true +``` + +Key settings: +- **Wildcard model support**: Route any model to any provider +- **Langfuse callbacks**: Automatic tracing on success/failure +- **Client-side credentials**: Accept API keys from request body + +## Security Considerations + +### Authentication +- **Default**: No authentication (`NoAuthProvider`) +- **Extensible**: Implement custom `AuthProvider` for production +- **API Keys**: Client API keys forwarded to LiteLLM, never stored + +### Trace Fetching Security +- **Required rollout_id tag**: Prevents fetching all traces +- **Project isolation**: Projects can only access their own Langfuse data +- **Optional auth**: `/traces` endpoint can require authentication + +### Best Practices +1. 
**Never commit `secrets.json`** - use environment variables in production +2. **Use HTTPS** in production deployments +3. **Implement proper authentication** for production use +4. **Rotate Langfuse keys** regularly +5. **Monitor Redis memory** usage for large rollouts + +## Deployment + +### Docker Compose (Development) +```bash +docker-compose up -d +``` + +### Kubernetes +Create deployment with: +- Secrets for `secrets.json` and Redis credentials +- Service for internal/external access +- ConfigMap for LiteLLM config +- Redis StatefulSet or managed Redis service + +## Development + +### Project Structure +``` +eval_protocol/proxy/ +├── proxy_core/ # Main application package +│ ├── __init__.py +│ ├── app.py # FastAPI routes +│ ├── litellm.py # LiteLLM client +│ ├── langfuse.py # Langfuse integration +│ ├── redis_utils.py # Redis operations +│ ├── models.py # Pydantic models +│ ├── auth.py # Authentication +│ ├── main.py # Entry point +│ └── secrets.json.example +├── docker-compose.yml # Local development stack +├── Dockerfile.gateway # Gateway container +├── config_no_cache.yaml # LiteLLM config +├── requirements.txt # Python dependencies +└── README.md # This file +``` + +### Adding Custom Authentication + +Extend `AuthProvider` in `auth.py`: +```python +from .auth import AuthProvider +from fastapi import HTTPException + +class MyAuthProvider(AuthProvider): + def validate(self, api_key: Optional[str]) -> Optional[str]: + if not api_key or not self.is_valid(api_key): + raise HTTPException(status_code=401, detail="Invalid API key") + return api_key + + def is_valid(self, api_key: str) -> bool: + # Your validation logic + return True +``` + +Then pass it to `create_app`: +```python +from proxy_core import create_app +from my_auth import MyAuthProvider + +app = create_app(auth_provider=MyAuthProvider()) +``` + +### Testing + +#### Test chat completion: +```bash +curl -X POST 
http://localhost:4000/rollout_id/test123/invocation_id/inv1/experiment_id/exp1/run_id/run1/row_id/row1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $FIREWORKS_API_KEY" \ + -d '{ + "model": "fireworks_ai/accounts/fireworks/models/llama-v3p3-70b-instruct", + "messages": [{"role": "user", "content": "Say hello"}] + }' +``` + +#### Test trace fetching: +```bash +curl "http://localhost:4000/traces?tags=rollout_id:test123" \ + -H "Authorization: Bearer your-auth-token" +``` + +#### Check Redis: +```bash +redis-cli +> SMEMBERS test123 # View insertion_ids for rollout +``` From 1eef32fcf8f230ad2ee2b46f90a4f62c77fd89a7 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Thu, 9 Oct 2025 01:30:59 -0700 Subject: [PATCH 03/15] import --- eval_protocol/__init__.py | 10 ++++++++++ eval_protocol/proxy/__init__.py | 15 +++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 eval_protocol/proxy/__init__.py diff --git a/eval_protocol/__init__.py b/eval_protocol/__init__.py index 6b96426f..9cb800e0 100644 --- a/eval_protocol/__init__.py +++ b/eval_protocol/__init__.py @@ -70,6 +70,13 @@ except ImportError: WeaveAdapter = None +try: + from .proxy import create_app, AuthProvider +except ImportError: + create_app = None + AuthProvider = None + + warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol") __all__ = [ @@ -130,6 +137,9 @@ "RolloutMetadata", "StatusResponse", "create_langfuse_config_tags", + # Proxy + "create_app", + "AuthProvider", ] from . import _version diff --git a/eval_protocol/proxy/__init__.py b/eval_protocol/proxy/__init__.py new file mode 100644 index 00000000..66765195 --- /dev/null +++ b/eval_protocol/proxy/__init__.py @@ -0,0 +1,15 @@ +""" +LiteLLM Metadata Extraction Gateway + +A proxy service for extracting evaluation metadata from URL paths and managing +Langfuse tracing for distributed evaluation workflows. 
+""" + +from .proxy_core import create_app, AuthProvider, NoAuthProvider, ProxyConfig + +__all__ = [ + "create_app", + "AuthProvider", + "NoAuthProvider", + "ProxyConfig", +] From c3764cc70d1fe790622bcf6871d3debf11b65de3 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Thu, 9 Oct 2025 13:19:41 -0700 Subject: [PATCH 04/15] address comments --- .gitignore | 2 +- eval_protocol/proxy/Dockerfile.gateway | 3 - eval_protocol/proxy/README.md | 72 +++++++++---------- eval_protocol/proxy/docker-compose.yml | 3 +- eval_protocol/proxy/proxy_core/app.py | 31 ++++---- eval_protocol/proxy/proxy_core/auth.py | 5 +- eval_protocol/proxy/proxy_core/langfuse.py | 2 +- eval_protocol/proxy/proxy_core/litellm.py | 2 +- eval_protocol/proxy/proxy_core/models.py | 1 + .../proxy/proxy_core/secrets.json.example | 13 ---- .../proxy/proxy_core/secrets.yaml.example | 14 ++++ eval_protocol/proxy/requirements.txt | 1 + tests/remote_server/test_remote_fireworks.py | 3 +- 13 files changed, 72 insertions(+), 80 deletions(-) delete mode 100644 eval_protocol/proxy/proxy_core/secrets.json.example create mode 100644 eval_protocol/proxy/proxy_core/secrets.yaml.example diff --git a/.gitignore b/.gitignore index d135d29f..d4f0d2df 100644 --- a/.gitignore +++ b/.gitignore @@ -106,7 +106,7 @@ venv.bak/ *.backup # Secrets -secrets.json +secrets.yaml # Spyder project settings .spyderproject diff --git a/eval_protocol/proxy/Dockerfile.gateway b/eval_protocol/proxy/Dockerfile.gateway index c791a8f5..a9308faa 100644 --- a/eval_protocol/proxy/Dockerfile.gateway +++ b/eval_protocol/proxy/Dockerfile.gateway @@ -6,9 +6,6 @@ WORKDIR /app # Prevent Python from buffering stdout/stderr ENV PYTHONUNBUFFERED=1 -# Set secrets path to proxy directory -ENV SECRETS_PATH=/app/proxy_core/secrets.json - # Copy requirements file COPY ./requirements.txt /app/requirements.txt diff --git a/eval_protocol/proxy/README.md b/eval_protocol/proxy/README.md index 2f9c62aa..dcf656b9 100644 --- a/eval_protocol/proxy/README.md +++ 
b/eval_protocol/proxy/README.md @@ -80,7 +80,7 @@ URL paths encode evaluation metadata that gets injected as Langfuse tags: 3. **Retry logic**: Automatic retries with exponential backoff for incomplete traces ### Multi-Project Support -- Store Langfuse credentials for multiple projects in `secrets.json` +- Store Langfuse credentials for multiple projects in `secrets.yaml` - Route requests to the correct project via `project_id` in URL or use default - Credentials never exposed to clients @@ -94,21 +94,17 @@ URL paths encode evaluation metadata that gets injected as Langfuse tags: 1. **Create secrets file:** ```bash - cp proxy_core/secrets.json.example proxy_core/secrets.json + cp proxy_core/secrets.yaml.example proxy_core/secrets.yaml ``` -2. **Edit `proxy_core/secrets.json`** with your Langfuse credentials. -**Important**: where we have "my-project", you would use the ID of your Langfuse project, similar to format `cmg00asdf0123...`. - ```json - { - "langfuse_keys": { - "my-project": { - "public_key": "pk-lf-...", - "secret_key": "sk-lf-..." - } - }, - "default_project_id": "my-project" - } +2. **Edit `proxy_core/secrets.yaml`** with your Langfuse credentials. +**Important**: use your real Langfuse project ID (e.g. `cmg00asdf0123...`). + ```yaml + langfuse_keys: + my-project: + public_key: pk-lf-... + secret_key: sk-lf-... + default_project_id: my-project ``` 3. **Start services:** @@ -238,31 +234,27 @@ Forwards any other request to LiteLLM backend with API key injection. 
| `REDIS_HOST` | Yes | - | Redis hostname | | `REDIS_PORT` | No | 6379 | Redis port | | `REDIS_PASSWORD` | No | - | Redis password | -| `SECRETS_PATH` | No | `proxy_core/secrets.json` | Path to secrets file | -| `REQUEST_TIMEOUT` | No | 300.0 | Request timeout in seconds | +| `SECRETS_PATH` | No | `proxy_core/secrets.yaml` | Path to secrets file (YAML) | +| `LANGFUSE_HOST` | No | `https://cloud.langfuse.com` | Langfuse base URL | +| `REQUEST_TIMEOUT` | No | 300.0 | Request timeout (LLM calls) in seconds | | `LOG_LEVEL` | No | INFO | Logging level | | `PORT` | No | 4000 | Gateway port | ### Secrets Configuration -Create `proxy_core/secrets.json`: -```json -{ - "langfuse_keys": { - "project-1": { - "public_key": "pk-lf-...", - "secret_key": "sk-lf-..." - }, - "project-2": { - "public_key": "pk-lf-...", - "secret_key": "sk-lf-..." - } - }, - "default_project_id": "project-1" -} +Create `proxy_core/secrets.yaml`: +```yaml +langfuse_keys: + project-1: + public_key: pk-lf-... + secret_key: sk-lf-... + project-2: + public_key: pk-lf-... + secret_key: sk-lf-... +default_project_id: project-1 ``` -**Security:** Add `secrets.json` to `.gitignore` (already configured). +**Security:** `secrets.yaml` is ignored via `.gitignore`. 
### LiteLLM Configuration @@ -332,7 +324,7 @@ eval_protocol/proxy/ │ ├── models.py # Pydantic models │ ├── auth.py # Authentication │ ├── main.py # Entry point -│ └── secrets.json.example +│ └── secrets.yaml.example ├── docker-compose.yml # Local development stack ├── Dockerfile.gateway # Gateway container ├── config_no_cache.yaml # LiteLLM config @@ -345,17 +337,17 @@ eval_protocol/proxy/ Extend `AuthProvider` in `auth.py`: ```python from .auth import AuthProvider -from fastapi import HTTPException +from fastapi import HTTPException, Request class MyAuthProvider(AuthProvider): - def validate(self, api_key: Optional[str]) -> Optional[str]: - if not api_key or not self.is_valid(api_key): + def validate(self, request: Request) -> Optional[str]: + api_key = None + auth_header = request.headers.get("authorization", "") + if auth_header.startswith("Bearer "): + api_key = auth_header.replace("Bearer ", "").strip() + if not api_key: raise HTTPException(status_code=401, detail="Invalid API key") return api_key - - def is_valid(self, api_key: str) -> bool: - # Your validation logic - return True ``` Then pass it to `create_app`: diff --git a/eval_protocol/proxy/docker-compose.yml b/eval_protocol/proxy/docker-compose.yml index 64eec602..12789e41 100644 --- a/eval_protocol/proxy/docker-compose.yml +++ b/eval_protocol/proxy/docker-compose.yml @@ -22,7 +22,6 @@ services: environment: - LANGFUSE_PUBLIC_KEY=dummy # Set dummy public and private key so Langfuse instance initializes in LiteLLM, then real keys get sent in proxy - LANGFUSE_SECRET_KEY=dummy - - LANGFUSE_HOST=https://langfuse.fireworks.ai volumes: - ./config_no_cache.yaml:/app/config.yaml:ro ports: @@ -48,6 +47,8 @@ services: - REQUEST_TIMEOUT=300 # Logging level: INFO (default) - LOG_LEVEL=INFO + # Langfuse and secrets + - SECRETS_PATH=/app/proxy_core/secrets.yaml ports: - "4000:4000" # Main public-facing port networks: diff --git a/eval_protocol/proxy/proxy_core/app.py b/eval_protocol/proxy/proxy_core/app.py index 
d0fe3059..567c0fcc 100644 --- a/eval_protocol/proxy/proxy_core/app.py +++ b/eval_protocol/proxy/proxy_core/app.py @@ -8,7 +8,7 @@ import os import redis import logging -import json +import yaml from pathlib import Path import sys from contextlib import asynccontextmanager @@ -36,33 +36,35 @@ def build_proxy_config() -> ProxyConfig: if not litellm_url: raise ValueError("LITELLM_URL environment variable must be set") request_timeout = float(os.getenv("REQUEST_TIMEOUT", "300.0")) + langfuse_host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com") - # Secrets - use SECRETS_PATH env var if set, otherwise default to proxy/secrets.json + # Secrets - use SECRETS_PATH env var if set, otherwise default to proxy/secrets.yaml secrets_path_str = os.getenv("SECRETS_PATH") if secrets_path_str: secrets_path = Path(secrets_path_str) else: - secrets_path = Path(__file__).parent / "secrets.json" + secrets_path = Path(__file__).parent / "secrets.yaml" if not secrets_path.exists(): raise ValueError( - "secrets.json not found! Please create it from secrets.json.example:\n" - " cp litellm_proxy_config/proxy/secrets.json.example litellm_proxy_config/proxy/secrets.json\n" - "Then add your Langfuse API keys to secrets.json" + "Secrets file not found! 
Please create it from secrets.yaml.example:\n" + " cp eval_protocol/proxy/proxy_core/secrets.yaml.example eval_protocol/proxy/proxy_core/secrets.yaml\n" + "Then add your Langfuse API keys to the secrets file" ) try: with open(secrets_path, "r") as f: - secrets_config = json.load(f) + secrets_config = yaml.safe_load(f) langfuse_keys = secrets_config["langfuse_keys"] default_project_id = secrets_config["default_project_id"] - logger.info(f"Loaded {len(langfuse_keys)} Langfuse project(s) from secrets.json") + logger.info(f"Loaded {len(langfuse_keys)} Langfuse project(s) from {secrets_path.name}") except KeyError as e: - raise ValueError(f"Missing required key in secrets.json: {e}") - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON in secrets.json: {e}") + raise ValueError(f"Missing required key in secrets file: {e}") + except yaml.YAMLError as e: + raise ValueError(f"Invalid format in secrets file {secrets_path.name}: {e}") return ProxyConfig( litellm_url=litellm_url, request_timeout=request_timeout, + langfuse_host=langfuse_host, langfuse_keys=langfuse_keys, default_project_id=default_project_id, ) @@ -118,12 +120,7 @@ def get_redis(request: Request) -> redis.Redis: return request.app.state.redis async def require_auth(request: Request) -> None: - auth_header = request.headers.get("authorization", "") - api_key = None - if auth_header.startswith("Bearer "): - api_key = auth_header.replace("Bearer ", "").strip() - - auth_provider.validate(api_key) + auth_provider.validate(request) return None # ===================== diff --git a/eval_protocol/proxy/proxy_core/auth.py b/eval_protocol/proxy/proxy_core/auth.py index 8e163512..8479cdb7 100644 --- a/eval_protocol/proxy/proxy_core/auth.py +++ b/eval_protocol/proxy/proxy_core/auth.py @@ -1,12 +1,13 @@ from abc import ABC, abstractmethod from typing import Optional +from fastapi import Request class AuthProvider(ABC): @abstractmethod - def validate(self, api_key: Optional[str]) -> Optional[str]: ... 
+ def validate(self, request: Request) -> Optional[str]: ... class NoAuthProvider(AuthProvider): - def validate(self, api_key: Optional[str]) -> Optional[str]: + def validate(self, request: Request) -> Optional[str]: return None diff --git a/eval_protocol/proxy/proxy_core/langfuse.py b/eval_protocol/proxy/proxy_core/langfuse.py index b93fdb18..bdab34cb 100644 --- a/eval_protocol/proxy/proxy_core/langfuse.py +++ b/eval_protocol/proxy/proxy_core/langfuse.py @@ -210,7 +210,7 @@ async def fetch_langfuse_traces( langfuse_client = Langfuse( public_key=config.langfuse_keys[project_id]["public_key"], secret_key=config.langfuse_keys[project_id]["secret_key"], - host="https://langfuse.fireworks.ai", + host=config.langfuse_host, ) # Parse datetime strings if provided diff --git a/eval_protocol/proxy/proxy_core/litellm.py b/eval_protocol/proxy/proxy_core/litellm.py index 92849064..557d8301 100644 --- a/eval_protocol/proxy/proxy_core/litellm.py +++ b/eval_protocol/proxy/proxy_core/litellm.py @@ -87,7 +87,7 @@ async def handle_chat_completion( # Add Langfuse configuration data["langfuse_public_key"] = config.langfuse_keys[project_id]["public_key"] data["langfuse_secret_key"] = config.langfuse_keys[project_id]["secret_key"] - data["langfuse_host"] = "https://langfuse.fireworks.ai" + data["langfuse_host"] = config.langfuse_host # Forward to LiteLLM's standard /chat/completions endpoint # Set longer timeout for LLM API calls (LLMs can be slow) diff --git a/eval_protocol/proxy/proxy_core/models.py b/eval_protocol/proxy/proxy_core/models.py index 903d77f0..e77a2328 100644 --- a/eval_protocol/proxy/proxy_core/models.py +++ b/eval_protocol/proxy/proxy_core/models.py @@ -11,6 +11,7 @@ class ProxyConfig(BaseModel): litellm_url: str request_timeout: float = 300.0 + langfuse_host: str langfuse_keys: Dict[str, Dict[str, str]] default_project_id: str diff --git a/eval_protocol/proxy/proxy_core/secrets.json.example b/eval_protocol/proxy/proxy_core/secrets.json.example deleted file mode 100644 
index 5dd37878..00000000 --- a/eval_protocol/proxy/proxy_core/secrets.json.example +++ /dev/null @@ -1,13 +0,0 @@ -{ - "langfuse_keys": { - "your-project-id-1": { - "public_key": "pk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - "secret_key": "sk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" - }, - "your-project-id-2": { - "public_key": "pk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - "secret_key": "sk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" - } - }, - "default_project_id": "your-project-id-1" -} diff --git a/eval_protocol/proxy/proxy_core/secrets.yaml.example b/eval_protocol/proxy/proxy_core/secrets.yaml.example new file mode 100644 index 00000000..e010d657 --- /dev/null +++ b/eval_protocol/proxy/proxy_core/secrets.yaml.example @@ -0,0 +1,14 @@ +langfuse_keys: + project_1_id: + public_key: pk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + secret_key: sk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + project_2_id: + public_key: pk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + secret_key: sk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + project_3_id: + public_key: pk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + secret_key: sk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + project_4_id: + public_key: pk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + secret_key: sk-lf-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx +default_project_id: project_1_id diff --git a/eval_protocol/proxy/requirements.txt b/eval_protocol/proxy/requirements.txt index 7726397b..15d21d0b 100644 --- a/eval_protocol/proxy/requirements.txt +++ b/eval_protocol/proxy/requirements.txt @@ -4,3 +4,4 @@ httpx>=0.25.0 redis>=5.0.0 langfuse>=2.0.0 uuid6>=2025.0.0 +PyYAML>=6.0.0 diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index 3050b1f5..3bc41854 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -43,7 +43,7 @@ def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: base_url = config.model_base_url or 
"https://tracing.fireworks.ai" adapter = FireworksTracingAdapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7) def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: @@ -65,6 +65,7 @@ def rows() -> List[EvaluationRow]: ), rollout_processor=RemoteRolloutProcessor( remote_base_url="http://127.0.0.1:3000", + model_base_url="http://localhost:4000", timeout_seconds=180, output_data_loader=fireworks_output_data_loader, ), From 98d81a75608e700f91e697f88c149641328d9507 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Fri, 10 Oct 2025 00:24:47 -0700 Subject: [PATCH 05/15] Add preprocess hooks --- eval_protocol/proxy/README.md | 26 ----- eval_protocol/proxy/__init__.py | 4 +- eval_protocol/proxy/docker-compose.yml | 1 + eval_protocol/proxy/proxy_core/__init__.py | 4 +- eval_protocol/proxy/proxy_core/app.py | 122 +++++++++++++-------- eval_protocol/proxy/proxy_core/auth.py | 11 +- eval_protocol/proxy/proxy_core/langfuse.py | 52 +++++---- eval_protocol/proxy/proxy_core/litellm.py | 31 +++--- eval_protocol/proxy/proxy_core/models.py | 42 ++++++- 9 files changed, 177 insertions(+), 116 deletions(-) diff --git a/eval_protocol/proxy/README.md b/eval_protocol/proxy/README.md index dcf656b9..dc020890 100644 --- a/eval_protocol/proxy/README.md +++ b/eval_protocol/proxy/README.md @@ -332,32 +332,6 @@ eval_protocol/proxy/ └── README.md # This file ``` -### Adding Custom Authentication - -Extend `AuthProvider` in `auth.py`: -```python -from .auth import AuthProvider -from fastapi import HTTPException, Request - -class MyAuthProvider(AuthProvider): - def validate(self, request: Request) -> Optional[str]: - api_key = None - auth_header = request.headers.get("authorization", "") - if auth_header.startswith("Bearer "): - api_key = auth_header.replace("Bearer ", "").strip() - if not api_key: - raise 
HTTPException(status_code=401, detail="Invalid API key") - return api_key -``` - -Then pass it to `create_app`: -```python -from proxy_core import create_app -from my_auth import MyAuthProvider - -app = create_app(auth_provider=MyAuthProvider()) -``` - ### Testing #### Test chat completion: diff --git a/eval_protocol/proxy/__init__.py b/eval_protocol/proxy/__init__.py index 66765195..93bda257 100644 --- a/eval_protocol/proxy/__init__.py +++ b/eval_protocol/proxy/__init__.py @@ -5,11 +5,13 @@ Langfuse tracing for distributed evaluation workflows. """ -from .proxy_core import create_app, AuthProvider, NoAuthProvider, ProxyConfig +from .proxy_core import create_app, AuthProvider, NoAuthProvider, ProxyConfig, ChatParams, TracesParams __all__ = [ "create_app", "AuthProvider", "NoAuthProvider", "ProxyConfig", + "ChatParams", + "TracesParams", ] diff --git a/eval_protocol/proxy/docker-compose.yml b/eval_protocol/proxy/docker-compose.yml index 12789e41..4128e654 100644 --- a/eval_protocol/proxy/docker-compose.yml +++ b/eval_protocol/proxy/docker-compose.yml @@ -49,6 +49,7 @@ services: - LOG_LEVEL=INFO # Langfuse and secrets - SECRETS_PATH=/app/proxy_core/secrets.yaml + - LANGFUSE_HOST=${LANGFUSE_HOST:-https://cloud.langfuse.com} ports: - "4000:4000" # Main public-facing port networks: diff --git a/eval_protocol/proxy/proxy_core/__init__.py b/eval_protocol/proxy/proxy_core/__init__.py index eb6975fa..d221be71 100644 --- a/eval_protocol/proxy/proxy_core/__init__.py +++ b/eval_protocol/proxy/proxy_core/__init__.py @@ -1,9 +1,11 @@ -from .models import ProxyConfig +from .models import ProxyConfig, ChatParams, TracesParams from .app import create_app from .auth import AuthProvider, NoAuthProvider __all__ = [ "ProxyConfig", + "ChatParams", + "TracesParams", "create_app", "AuthProvider", "NoAuthProvider", diff --git a/eval_protocol/proxy/proxy_core/app.py b/eval_protocol/proxy/proxy_core/app.py index 567c0fcc..4b76d809 100644 --- a/eval_protocol/proxy/proxy_core/app.py +++ 
b/eval_protocol/proxy/proxy_core/app.py @@ -4,7 +4,7 @@ """ from fastapi import FastAPI, Depends, HTTPException, Request, Query -from typing import Optional, List +from typing import Optional, Callable, Dict, Any, List import os import redis import logging @@ -13,24 +13,28 @@ import sys from contextlib import asynccontextmanager -from .models import ProxyConfig, LangfuseTracesResponse +from .models import ProxyConfig, LangfuseTracesResponse, TracesParams, ChatParams, ChatRequestHook, TracesRequestHook from .auth import AuthProvider, NoAuthProvider from .litellm import handle_chat_completion, proxy_to_litellm from .langfuse import fetch_langfuse_traces # Configure logging before any other imports (so all modules inherit this config) log_level = os.getenv("LOG_LEVEL", "INFO").upper() -logging.basicConfig( - level=getattr(logging, log_level), - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler(sys.stdout)], -) +if not logging.getLogger().hasHandlers(): + logging.basicConfig( + level=getattr(logging, log_level), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], + ) logger = logging.getLogger(__name__) -def build_proxy_config() -> ProxyConfig: - """Load environment and secrets, and build ProxyConfig (no Redis).""" +def build_proxy_config( + preprocess_chat_request: Optional[ChatRequestHook] = None, + preprocess_traces_request: Optional[TracesRequestHook] = None, +) -> ProxyConfig: + """Load environment and secrets, and build ProxyConfig""" # Env litellm_url = os.getenv("LITELLM_URL") if not litellm_url: @@ -67,6 +71,8 @@ def build_proxy_config() -> ProxyConfig: langfuse_host=langfuse_host, langfuse_keys=langfuse_keys, default_project_id=default_project_id, + preprocess_chat_request=preprocess_chat_request, + preprocess_traces_request=preprocess_traces_request, ) @@ -97,12 +103,15 @@ def init_redis() -> redis.Redis: def create_app( auth_provider: AuthProvider = 
NoAuthProvider(), + preprocess_chat_request: Optional[ChatRequestHook] = None, + preprocess_traces_request: Optional[TracesRequestHook] = None, ) -> FastAPI: @asynccontextmanager async def lifespan(app: FastAPI): # Build runtime on startup - app.state.config = build_proxy_config() + app.state.config = build_proxy_config(preprocess_chat_request, preprocess_traces_request) app.state.redis = init_redis() + try: yield finally: @@ -119,8 +128,46 @@ def get_config(request: Request) -> ProxyConfig: def get_redis(request: Request) -> redis.Redis: return request.app.state.redis + def get_traces_params( + tags: Optional[List[str]] = Query(default=None), + project_id: Optional[str] = None, + limit: int = 100, + sample_size: Optional[int] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + name: Optional[str] = None, + environment: Optional[str] = None, + version: Optional[str] = None, + release: Optional[str] = None, + fields: Optional[str] = None, + hours_back: Optional[int] = None, + from_timestamp: Optional[str] = None, + to_timestamp: Optional[str] = None, + sleep_between_gets: float = 2.5, + max_retries: int = 3, + ) -> TracesParams: + return TracesParams( + tags=tags, + project_id=project_id, + limit=limit, + sample_size=sample_size, + user_id=user_id, + session_id=session_id, + name=name, + environment=environment, + version=version, + release=release, + fields=fields, + hours_back=hours_back, + from_timestamp=from_timestamp, + to_timestamp=to_timestamp, + sleep_between_gets=sleep_between_gets, + max_retries=max_retries, + ) + async def require_auth(request: Request) -> None: - auth_provider.validate(request) + account_id = auth_provider.validate_and_return_account_id(request) + request.state.account_id = account_id return None # ===================== @@ -161,11 +208,9 @@ async def chat_completion_with_full_metadata( encoded_base_url: Optional[str] = None, config: ProxyConfig = Depends(get_config), redis_client: redis.Redis = 
Depends(get_redis), + _: None = Depends(require_auth), ): - return await handle_chat_completion( - config=config, - redis_client=redis_client, - request=request, + params = ChatParams( project_id=project_id, rollout_id=rollout_id, invocation_id=invocation_id, @@ -174,6 +219,12 @@ async def chat_completion_with_full_metadata( row_id=row_id, encoded_base_url=encoded_base_url, ) + return await handle_chat_completion( + config=config, + redis_client=redis_client, + request=request, + params=params, + ) @app.post("/project_id/{project_id}/chat/completions") @app.post("/v1/project_id/{project_id}/chat/completions") @@ -182,12 +233,14 @@ async def chat_completion_with_project_only( request: Request, config: ProxyConfig = Depends(get_config), redis_client: redis.Redis = Depends(get_redis), + _: None = Depends(require_auth), ): + params = ChatParams(project_id=project_id) return await handle_chat_completion( config=config, redis_client=redis_client, request=request, - project_id=project_id, + params=params, ) # =============== @@ -198,45 +251,20 @@ async def chat_completion_with_project_only( @app.get("/project_id/{project_id}/traces", response_model=LangfuseTracesResponse) @app.get("/v1/project_id/{project_id}/traces", response_model=LangfuseTracesResponse) async def get_langfuse_traces( - tags: List[str] = Query(...), # REQUIRED query param + request: Request, + params: TracesParams = Depends(get_traces_params), project_id: Optional[str] = None, - limit: int = 100, - sample_size: Optional[int] = None, - user_id: Optional[str] = None, - session_id: Optional[str] = None, - name: Optional[str] = None, - environment: Optional[str] = None, - version: Optional[str] = None, - release: Optional[str] = None, - fields: Optional[str] = None, - hours_back: Optional[int] = None, - from_timestamp: Optional[str] = None, - to_timestamp: Optional[str] = None, - sleep_between_gets: float = 2.5, - max_retries: int = 3, config: ProxyConfig = Depends(get_config), redis_client: redis.Redis = 
Depends(get_redis), _: None = Depends(require_auth), ) -> LangfuseTracesResponse: + if project_id is not None: + params.project_id = project_id return await fetch_langfuse_traces( config=config, redis_client=redis_client, - tags=tags, - project_id=project_id, - limit=limit, - sample_size=sample_size, - user_id=user_id, - session_id=session_id, - name=name, - environment=environment, - version=version, - release=release, - fields=fields, - hours_back=hours_back, - from_timestamp=from_timestamp, - to_timestamp=to_timestamp, - sleep_between_gets=sleep_between_gets, - max_retries=max_retries, + request=request, + params=params, ) # Health diff --git a/eval_protocol/proxy/proxy_core/auth.py b/eval_protocol/proxy/proxy_core/auth.py index 8479cdb7..fed24845 100644 --- a/eval_protocol/proxy/proxy_core/auth.py +++ b/eval_protocol/proxy/proxy_core/auth.py @@ -1,13 +1,18 @@ from abc import ABC, abstractmethod -from typing import Optional +import logging from fastapi import Request +from fastapi import HTTPException +import httpx +from typing import Optional + +logger = logging.getLogger(__name__) class AuthProvider(ABC): @abstractmethod - def validate(self, request: Request) -> Optional[str]: ... + def validate_and_return_account_id(self, request: Request) -> Optional[str]: ... 
class NoAuthProvider(AuthProvider): - def validate(self, request: Request) -> Optional[str]: + def validate_and_return_account_id(self, request: Request) -> Optional[str]: return None diff --git a/eval_protocol/proxy/proxy_core/langfuse.py b/eval_protocol/proxy/proxy_core/langfuse.py index bdab34cb..ce128ee3 100644 --- a/eval_protocol/proxy/proxy_core/langfuse.py +++ b/eval_protocol/proxy/proxy_core/langfuse.py @@ -8,16 +8,18 @@ import asyncio from typing import List, Optional, Dict, Any, Set from datetime import datetime, timedelta -from fastapi import HTTPException +from fastapi import HTTPException, Request import redis from .redis_utils import get_insertion_ids -from .models import ProxyConfig, LangfuseTracesResponse, TraceResponse +from .models import ProxyConfig, LangfuseTracesResponse, TraceResponse, TracesParams logger = logging.getLogger(__name__) -def _extract_tag_value(tags: List[str], prefix: str) -> Optional[str]: +def _extract_tag_value(tags: Optional[List[str]], prefix: str) -> Optional[str]: """Extract value from a tag with the given prefix (e.g., 'rollout_id:' or 'insertion_id:').""" + if not tags: + return None for tag in tags: if tag.startswith(prefix): return tag.split(":", 1)[1] @@ -60,7 +62,7 @@ async def _fetch_trace_list_with_retry( langfuse_client: Any, page: int, limit: int, - tags: List[str], + tags: Optional[List[str]], user_id: Optional[str], session_id: Optional[str], name: Optional[str], @@ -152,22 +154,8 @@ async def _fetch_trace_detail_with_retry( async def fetch_langfuse_traces( config: ProxyConfig, redis_client: redis.Redis, - tags: List[str], - project_id: Optional[str] = None, - limit: int = 100, - sample_size: Optional[int] = None, - user_id: Optional[str] = None, - session_id: Optional[str] = None, - name: Optional[str] = None, - environment: Optional[str] = None, - version: Optional[str] = None, - release: Optional[str] = None, - fields: Optional[str] = None, - hours_back: Optional[int] = None, - from_timestamp: Optional[str] 
= None, - to_timestamp: Optional[str] = None, - sleep_between_gets: float = 2.5, - max_retries: int = 3, + request: Request, + params: TracesParams, ): """ Fetch full traces from Langfuse for the specified project. @@ -184,9 +172,27 @@ async def fetch_langfuse_traces( Returns a list of full trace objects (including observations) in JSON format. """ - # Validate tags - if not tags or not any(tag.startswith("rollout_id:") for tag in tags): - raise HTTPException(status_code=422, detail="Tags must include at least one 'rollout_id:*' tag") + + # Preprocess traces request + if config.preprocess_traces_request: + params = config.preprocess_traces_request(request, params) + + tags = params.tags + project_id = params.project_id + limit = params.limit + sample_size = params.sample_size + user_id = params.user_id + session_id = params.session_id + name = params.name + environment = params.environment + version = params.version + release = params.release + fields = params.fields + hours_back = params.hours_back + from_timestamp = params.from_timestamp + to_timestamp = params.to_timestamp + sleep_between_gets = params.sleep_between_gets + max_retries = params.max_retries # Use default project if not specified if project_id is None: diff --git a/eval_protocol/proxy/proxy_core/litellm.py b/eval_protocol/proxy/proxy_core/litellm.py index 557d8301..bade23fb 100644 --- a/eval_protocol/proxy/proxy_core/litellm.py +++ b/eval_protocol/proxy/proxy_core/litellm.py @@ -11,7 +11,7 @@ from typing import Optional import redis from .redis_utils import register_insertion_id -from .models import ProxyConfig +from .models import ProxyConfig, ChatParams logger = logging.getLogger(__name__) @@ -20,13 +20,7 @@ async def handle_chat_completion( config: ProxyConfig, redis_client: redis.Redis, request: Request, - project_id: Optional[str] = None, - rollout_id: Optional[str] = None, - invocation_id: Optional[str] = None, - experiment_id: Optional[str] = None, - run_id: Optional[str] = None, - row_id: 
Optional[str] = None, - encoded_base_url: Optional[str] = None, + params: ChatParams, ) -> Response: """ Handle chat completion requests and forward to LiteLLM. @@ -36,14 +30,24 @@ async def handle_chat_completion( If encoded_base_url is provided, it will be decoded and added to the request. """ + body = await request.body() + data = json.loads(body) if body else {} + + if config.preprocess_chat_request: + data, params = config.preprocess_chat_request(data, request, params) + + project_id = params.project_id + rollout_id = params.rollout_id + invocation_id = params.invocation_id + experiment_id = params.experiment_id + run_id = params.run_id + row_id = params.row_id + encoded_base_url = params.encoded_base_url + # Use default project if not specified if project_id is None: project_id = config.default_project_id - # Read the original request body - body = await request.body() - data = json.loads(body) if body else {} - # Decode and add base_url if provided if encoded_base_url: try: @@ -135,12 +139,11 @@ async def proxy_to_litellm(config: ProxyConfig, path: str, request: Request) -> # Get body body = await request.body() - # For POST/PUT/PATCH with JSON, extract API key from header + # Pass through API key from Authorization header if request.method in ["POST", "PUT", "PATCH"] and body: try: data = json.loads(body) - # Extract API key from Authorization header auth_header = request.headers.get("authorization", "") if auth_header.startswith("Bearer "): api_key = auth_header.replace("Bearer ", "").strip() diff --git a/eval_protocol/proxy/proxy_core/models.py b/eval_protocol/proxy/proxy_core/models.py index e77a2328..9bbc92fe 100644 --- a/eval_protocol/proxy/proxy_core/models.py +++ b/eval_protocol/proxy/proxy_core/models.py @@ -3,7 +3,45 @@ """ from pydantic import BaseModel -from typing import Optional, List, Any, Dict +from typing import Optional, List, Any, Dict, Callable, TypeAlias +from fastapi import Request, Query + + +ChatRequestHook: TypeAlias = 
Callable[[Dict[str, Any], Request, "ChatParams"], tuple[Dict[str, Any], "ChatParams"]] +TracesRequestHook: TypeAlias = Callable[[Request, "TracesParams"], "TracesParams"] + + +class ChatParams(BaseModel): + """Typed container for chat completion URL path parameters.""" + + project_id: Optional[str] = None + rollout_id: Optional[str] = None + invocation_id: Optional[str] = None + experiment_id: Optional[str] = None + run_id: Optional[str] = None + row_id: Optional[str] = None + encoded_base_url: Optional[str] = None + + +class TracesParams(BaseModel): + """Typed container for traces query parameters and controls.""" + + tags: Optional[List[str]] = None + project_id: Optional[str] = None + limit: int = 100 + sample_size: Optional[int] = None + user_id: Optional[str] = None + session_id: Optional[str] = None + name: Optional[str] = None + environment: Optional[str] = None + version: Optional[str] = None + release: Optional[str] = None + fields: Optional[str] = None + hours_back: Optional[int] = None + from_timestamp: Optional[str] = None + to_timestamp: Optional[str] = None + sleep_between_gets: float = 2.5 + max_retries: int = 3 class ProxyConfig(BaseModel): @@ -14,6 +52,8 @@ class ProxyConfig(BaseModel): langfuse_host: str langfuse_keys: Dict[str, Dict[str, str]] default_project_id: str + preprocess_chat_request: Optional[ChatRequestHook] = None + preprocess_traces_request: Optional[TracesRequestHook] = None class ObservationResponse(BaseModel): From 2f5eec67ccb966ed7c089a95e1e0121a75ad3e77 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Fri, 10 Oct 2025 00:55:12 -0700 Subject: [PATCH 06/15] better logs --- eval_protocol/proxy/proxy_core/langfuse.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/eval_protocol/proxy/proxy_core/langfuse.py b/eval_protocol/proxy/proxy_core/langfuse.py index ce128ee3..688fcb0d 100644 --- a/eval_protocol/proxy/proxy_core/langfuse.py +++ b/eval_protocol/proxy/proxy_core/langfuse.py @@ -236,6 +236,7 @@ async 
def fetch_langfuse_traces( expected_ids: Set[str] = set() if rollout_id: expected_ids = get_insertion_ids(redis_client, rollout_id) + logger.info(f"Fetching traces for rollout_id '{rollout_id}', expecting {len(expected_ids)} insertion_ids") if not expected_ids: logger.warning( f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty traces." @@ -258,7 +259,9 @@ async def fetch_langfuse_traces( # Build targeted tags for missing insertion_ids missing_ids = expected_ids - insertion_ids fetch_tags = [f"insertion_id:{id}" for id in missing_ids] - logger.info(f"Retry {retry}: Targeting {len(fetch_tags)} missing insertion_ids") + logger.info( + f"Retry {retry}: Targeting {len(fetch_tags)} missing insertion_ids for rollout '{rollout_id}': {[id[:5] for id in sorted(missing_ids)[:10]]}{'...' if len(missing_ids) > 10 else ''}" + ) current_page = 1 collected = 0 @@ -313,6 +316,7 @@ async def fetch_langfuse_traces( insertion_id = _extract_tag_value(trace_dict.get("tags", []), "insertion_id:") if insertion_id: insertion_ids.add(insertion_id) + logger.debug(f"Found insertion_id '{insertion_id}' for rollout '{rollout_id}'") except Exception as e: logger.warning("Failed to serialize trace %s: %s", trace_info.id, e) @@ -331,8 +335,12 @@ async def fetch_langfuse_traces( # If we have all expected completions or more, return traces. At least once is ok. 
if expected_ids <= insertion_ids: + logger.info( + f"Traces complete for rollout '{rollout_id}': {len(insertion_ids)}/{len(expected_ids)} insertion_ids found, returning {len(all_traces)} traces" + ) if sample_size is not None and len(all_traces) > sample_size: all_traces = random.sample(all_traces, sample_size) + logger.info(f"Sampled down to {sample_size} traces") return LangfuseTracesResponse( project_id=project_id, @@ -343,8 +351,9 @@ async def fetch_langfuse_traces( # If it doesn't match, wait and do loop again (exponential backoff) if retry < max_retries - 1: wait_time = 2**retry + still_missing = expected_ids - insertion_ids logger.info( - f"Attempt {retry + 1}/{max_retries}. Found {len(insertion_ids)}/{len(expected_ids)} expected. Waiting {wait_time}s..." + f"Attempt {retry + 1}/{max_retries}. Found {len(insertion_ids)}/{len(expected_ids)} for rollout '{rollout_id}'. Still missing: {[id[:5] for id in sorted(still_missing)[:10]]}{'...' if len(still_missing) > 10 else ''}. Waiting {wait_time}s..." ) await asyncio.sleep(wait_time) From d01acd641a3dd3ea7d12607598181c1c3bb1afe8 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Fri, 10 Oct 2025 01:05:01 -0700 Subject: [PATCH 07/15] logs --- eval_protocol/proxy/proxy_core/langfuse.py | 4 ++-- tests/remote_server/test_remote_fireworks.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/eval_protocol/proxy/proxy_core/langfuse.py b/eval_protocol/proxy/proxy_core/langfuse.py index 688fcb0d..8c589da4 100644 --- a/eval_protocol/proxy/proxy_core/langfuse.py +++ b/eval_protocol/proxy/proxy_core/langfuse.py @@ -260,7 +260,7 @@ async def fetch_langfuse_traces( missing_ids = expected_ids - insertion_ids fetch_tags = [f"insertion_id:{id}" for id in missing_ids] logger.info( - f"Retry {retry}: Targeting {len(fetch_tags)} missing insertion_ids for rollout '{rollout_id}': {[id[:5] for id in sorted(missing_ids)[:10]]}{'...' 
if len(missing_ids) > 10 else ''}" + f"Retry {retry}: Targeting {len(fetch_tags)} missing insertion_ids for rollout '{rollout_id}' (last5): {[id[-5:] for id in sorted(missing_ids)[:10]]}{'...' if len(missing_ids) > 10 else ''}" ) current_page = 1 @@ -353,7 +353,7 @@ async def fetch_langfuse_traces( wait_time = 2**retry still_missing = expected_ids - insertion_ids logger.info( - f"Attempt {retry + 1}/{max_retries}. Found {len(insertion_ids)}/{len(expected_ids)} for rollout '{rollout_id}'. Still missing: {[id[:5] for id in sorted(still_missing)[:10]]}{'...' if len(still_missing) > 10 else ''}. Waiting {wait_time}s..." + f"Attempt {retry + 1}/{max_retries}. Found {len(insertion_ids)}/{len(expected_ids)} for rollout '{rollout_id}'. Still missing (last5): {[id[-5:] for id in sorted(still_missing)[:10]]}{'...' if len(still_missing) > 10 else ''}. Waiting {wait_time}s..." ) await asyncio.sleep(wait_time) diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index 3bc41854..d27ace02 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -65,7 +65,6 @@ def rows() -> List[EvaluationRow]: ), rollout_processor=RemoteRolloutProcessor( remote_base_url="http://127.0.0.1:3000", - model_base_url="http://localhost:4000", timeout_seconds=180, output_data_loader=fireworks_output_data_loader, ), From 859dce2af419fc26a8634c1254a75a52d9c246d4 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Fri, 10 Oct 2025 01:09:04 -0700 Subject: [PATCH 08/15] remove comment --- eval_protocol/proxy/proxy_core/langfuse.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/eval_protocol/proxy/proxy_core/langfuse.py b/eval_protocol/proxy/proxy_core/langfuse.py index 8c589da4..7f13b40c 100644 --- a/eval_protocol/proxy/proxy_core/langfuse.py +++ b/eval_protocol/proxy/proxy_core/langfuse.py @@ -163,11 +163,6 @@ async def fetch_langfuse_traces( This endpoint uses the stored Langfuse keys for the project 
and polls traces based on the provided filters. - SECURITY: - - Tags are REQUIRED and must not be empty - - At least one tag MUST be in the format 'rollout_id:*' - - This prevents accidentally fetching all traces or traces from other clients - If project_id is not provided, uses the default project. Returns a list of full trace objects (including observations) in JSON format. From 1757548441eb93afd5dc0428b0218637787cdd80 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Fri, 10 Oct 2025 01:45:32 -0700 Subject: [PATCH 09/15] Pointwise Mode --- eval_protocol/adapters/fireworks_tracing.py | 8 +- eval_protocol/proxy/README.md | 16 +- eval_protocol/proxy/proxy_core/app.py | 23 ++- eval_protocol/proxy/proxy_core/langfuse.py | 158 ++++++++++++++++++++ 4 files changed, 199 insertions(+), 6 deletions(-) diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 816718fe..707f983a 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -343,11 +343,11 @@ def get_evaluation_rows( # Remove None values params = {k: v for k, v in params.items() if v is not None} - # Make request to proxy + # Make request to proxy (using pointwise for efficiency) if self.project_id: - url = f"{self.base_url}/v1/project_id/{self.project_id}/traces" + url = f"{self.base_url}/v1/project_id/{self.project_id}/traces/pointwise" else: - url = f"{self.base_url}/v1/traces" + url = f"{self.base_url}/v1/traces/pointwise" headers = {"Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}"} @@ -367,7 +367,7 @@ def get_evaluation_rows( except Exception: # In case e.response.json() fails error_msg = f"Proxy error: {e.response.text}" - logger.error("Failed to fetch traces from proxy: %s", error_msg) + logger.error("Failed to fetch traces from proxy (HTTP %s): %s", e.response.status_code, error_msg) return eval_rows except requests.exceptions.RequestException as e: # Non-HTTP errors (network issues, 
timeouts, etc.) diff --git a/eval_protocol/proxy/README.md b/eval_protocol/proxy/README.md index dc020890..ffcdaf25 100644 --- a/eval_protocol/proxy/README.md +++ b/eval_protocol/proxy/README.md @@ -166,12 +166,26 @@ The `encoded_base_url` is base64-encoded URL string injected into the request bo ### Trace Fetching -#### Fetch Langfuse Traces +#### Fetch All Langfuse Traces ``` GET /traces?tags=rollout_id:abc123 +GET /v1/traces?tags=rollout_id:abc123 GET /project_id/{project_id}/traces?tags=rollout_id:abc123 +GET /v1/project_id/{project_id}/traces?tags=rollout_id:abc123 ``` +Waits for all expected insertion_ids to complete before returning all traces. + +#### Fetch Latest Langfuse Trace (Pointwise) +``` +GET /traces/pointwise?tags=rollout_id:abc123 +GET /v1/traces/pointwise?tags=rollout_id:abc123 +GET /project_id/{project_id}/traces/pointwise?tags=rollout_id:abc123 +GET /v1/project_id/{project_id}/traces/pointwise?tags=rollout_id:abc123 +``` + +Returns only the latest trace (UUID v7 time-ordered). Much faster for pointwise evaluations where you only need the final accumulated result. 
+ **Required Query Parameters:** - `tags`: Array of tags (must include at least one `rollout_id:*` tag) diff --git a/eval_protocol/proxy/proxy_core/app.py b/eval_protocol/proxy/proxy_core/app.py index 4b76d809..77286109 100644 --- a/eval_protocol/proxy/proxy_core/app.py +++ b/eval_protocol/proxy/proxy_core/app.py @@ -16,7 +16,7 @@ from .models import ProxyConfig, LangfuseTracesResponse, TracesParams, ChatParams, ChatRequestHook, TracesRequestHook from .auth import AuthProvider, NoAuthProvider from .litellm import handle_chat_completion, proxy_to_litellm -from .langfuse import fetch_langfuse_traces +from .langfuse import fetch_langfuse_traces, pointwise_fetch_langfuse_trace # Configure logging before any other imports (so all modules inherit this config) log_level = os.getenv("LOG_LEVEL", "INFO").upper() @@ -267,6 +267,27 @@ async def get_langfuse_traces( params=params, ) + @app.get("/traces/pointwise", response_model=LangfuseTracesResponse) + @app.get("/v1/traces/pointwise", response_model=LangfuseTracesResponse) + @app.get("/project_id/{project_id}/traces/pointwise", response_model=LangfuseTracesResponse) + @app.get("/v1/project_id/{project_id}/traces/pointwise", response_model=LangfuseTracesResponse) + async def pointwise_get_langfuse_trace( + request: Request, + params: TracesParams = Depends(get_traces_params), + project_id: Optional[str] = None, + config: ProxyConfig = Depends(get_config), + redis_client: redis.Redis = Depends(get_redis), + _: None = Depends(require_auth), + ) -> LangfuseTracesResponse: + if project_id is not None: + params.project_id = project_id + return await pointwise_fetch_langfuse_trace( + config=config, + redis_client=redis_client, + request=request, + params=params, + ) + # Health @app.get("/health") async def health(): diff --git a/eval_protocol/proxy/proxy_core/langfuse.py b/eval_protocol/proxy/proxy_core/langfuse.py index 7f13b40c..67701971 100644 --- a/eval_protocol/proxy/proxy_core/langfuse.py +++ 
b/eval_protocol/proxy/proxy_core/langfuse.py @@ -366,3 +366,161 @@ async def fetch_langfuse_traces( raise except Exception as e: raise HTTPException(status_code=500, detail=f"Error fetching traces from Langfuse: {str(e)}") + + +async def pointwise_fetch_langfuse_trace( + config: ProxyConfig, + redis_client: redis.Redis, + request: Request, + params: TracesParams, +): + """ + Fetch the latest trace from Langfuse for the specified project. + + Since insertion_ids are UUID v7 (time-ordered), we only fetch the last one + as it contains all accumulated information from the pointwise evaluation. + + Returns a single trace object or raises if not found. + """ + + # Preprocess traces request + if config.preprocess_traces_request: + params = config.preprocess_traces_request(request, params) + + tags = params.tags + project_id = params.project_id + user_id = params.user_id + session_id = params.session_id + name = params.name + environment = params.environment + version = params.version + release = params.release + fields = params.fields + hours_back = params.hours_back + from_timestamp = params.from_timestamp + to_timestamp = params.to_timestamp + sleep_between_gets = params.sleep_between_gets + max_retries = params.max_retries + + # Use default project if not specified + if project_id is None: + project_id = config.default_project_id + + # Validate project_id + if project_id not in config.langfuse_keys: + raise HTTPException( + status_code=404, + detail=f"Project ID '{project_id}' not found. 
Available projects: {list(config.langfuse_keys.keys())}", + ) + + # Extract rollout_id from tags for Redis lookup + rollout_id = _extract_tag_value(tags, "rollout_id:") + + try: + # Import the Langfuse adapter + from langfuse import Langfuse + + # Create Langfuse client with the project's keys + logger.debug(f"Connecting to Langfuse at {config.langfuse_host} for project '{project_id}'") + langfuse_client = Langfuse( + public_key=config.langfuse_keys[project_id]["public_key"], + secret_key=config.langfuse_keys[project_id]["secret_key"], + host=config.langfuse_host, + ) + + # Parse datetime strings if provided + from_ts = None + to_ts = None + if from_timestamp: + from_ts = datetime.fromisoformat(from_timestamp.replace("Z", "+00:00")) + if to_timestamp: + to_ts = datetime.fromisoformat(to_timestamp.replace("Z", "+00:00")) + + # Determine time window: explicit from/to takes precedence over hours_back + if from_ts is None and to_ts is None and hours_back: + to_ts = datetime.now() + from_ts = to_ts - timedelta(hours=hours_back) + + # Get expected insertion_ids from Redis for completeness checking + expected_ids: Set[str] = set() + if rollout_id: + expected_ids = get_insertion_ids(redis_client, rollout_id) + logger.info(f"Pointwise fetch for rollout_id '{rollout_id}', expecting {len(expected_ids)} insertion_ids") + if not expected_ids: + logger.warning( + f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty trace." + ) + raise HTTPException( + status_code=500, + detail=f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. 
Returning empty trace.", + ) + + # Get the latest (last) insertion_id since UUID v7 is time-ordered + latest_insertion_id = max(expected_ids) # UUID v7 max = newest + logger.info(f"Targeting latest insertion_id (last5): {latest_insertion_id[-5:]} for rollout '{rollout_id}'") + + for retry in range(max_retries): + # Fetch trace list targeting the latest insertion_id + traces = await _fetch_trace_list_with_retry( + langfuse_client, + page=1, + limit=1, # Only need the one trace + tags=[f"insertion_id:{latest_insertion_id}"], + user_id=user_id, + session_id=session_id, + name=name, + environment=environment, + version=version, + release=release, + fields=fields, + from_ts=from_ts, + to_ts=to_ts, + max_retries=max_retries, + ) + + if traces and traces.data: + # Get the trace info + trace_info = traces.data[0] + logger.debug(f"Found trace {trace_info.id} for latest insertion_id {latest_insertion_id[-5:]}") + + # Fetch full trace details + trace_full = await _fetch_trace_detail_with_retry( + langfuse_client, + trace_info.id, + max_retries, + ) + + if trace_full: + trace_dict = _serialize_trace_to_dict(trace_full) + logger.info( + f"Successfully fetched latest trace for rollout '{rollout_id}', insertion_id (last5): {latest_insertion_id[-5:]}" + ) + return LangfuseTracesResponse( + project_id=project_id, + total_traces=1, + traces=[TraceResponse(**trace_dict)], + ) + + # If not successful and not last retry, sleep and continue + if retry < max_retries - 1: + wait_time = 2**retry + logger.info( + f"Pointwise fetch attempt {retry + 1}/{max_retries} failed for rollout '{rollout_id}', insertion_id (last5): {latest_insertion_id[-5:]}. Retrying in {wait_time}s..." 
+ ) + await asyncio.sleep(wait_time) + + # After all retries failed + logger.error( + f"Failed to fetch latest trace for rollout '{rollout_id}', insertion_id (last5): {latest_insertion_id[-5:]} after {max_retries} retries" + ) + raise HTTPException( + status_code=404, + detail=f"Failed to fetch latest trace for rollout '{rollout_id}' after {max_retries} retries", + ) + + except ImportError: + raise HTTPException(status_code=500, detail="Langfuse SDK not installed. Install with: pip install langfuse") + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching latest trace from Langfuse: {str(e)}") From 63a0c656d52f27f185654702ca408307febcee53 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Fri, 10 Oct 2025 01:55:38 -0700 Subject: [PATCH 10/15] fix logs --- eval_protocol/proxy/proxy_core/langfuse.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/eval_protocol/proxy/proxy_core/langfuse.py b/eval_protocol/proxy/proxy_core/langfuse.py index 67701971..9764e3c9 100644 --- a/eval_protocol/proxy/proxy_core/langfuse.py +++ b/eval_protocol/proxy/proxy_core/langfuse.py @@ -441,23 +441,25 @@ async def pointwise_fetch_langfuse_trace( to_ts = datetime.now() from_ts = to_ts - timedelta(hours=hours_back) - # Get expected insertion_ids from Redis for completeness checking + # Get insertion_ids from Redis to find the latest one expected_ids: Set[str] = set() if rollout_id: expected_ids = get_insertion_ids(redis_client, rollout_id) - logger.info(f"Pointwise fetch for rollout_id '{rollout_id}', expecting {len(expected_ids)} insertion_ids") + logger.info( + f"Pointwise fetch for rollout_id '{rollout_id}', found {len(expected_ids)} insertion_ids in Redis" + ) if not expected_ids: logger.warning( - f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty trace." + f"No insertion_ids found in Redis for rollout '{rollout_id}'. Cannot determine latest trace." 
) raise HTTPException( status_code=500, - detail=f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty trace.", + detail=f"No insertion_ids found in Redis for rollout '{rollout_id}'. Cannot determine latest trace.", ) # Get the latest (last) insertion_id since UUID v7 is time-ordered latest_insertion_id = max(expected_ids) # UUID v7 max = newest - logger.info(f"Targeting latest insertion_id (last5): {latest_insertion_id[-5:]} for rollout '{rollout_id}'") + logger.info(f"Targeting latest insertion_id: {latest_insertion_id} for rollout '{rollout_id}'") for retry in range(max_retries): # Fetch trace list targeting the latest insertion_id @@ -481,7 +483,7 @@ async def pointwise_fetch_langfuse_trace( if traces and traces.data: # Get the trace info trace_info = traces.data[0] - logger.debug(f"Found trace {trace_info.id} for latest insertion_id {latest_insertion_id[-5:]}") + logger.debug(f"Found trace {trace_info.id} for latest insertion_id {latest_insertion_id}") # Fetch full trace details trace_full = await _fetch_trace_detail_with_retry( @@ -493,7 +495,7 @@ async def pointwise_fetch_langfuse_trace( if trace_full: trace_dict = _serialize_trace_to_dict(trace_full) logger.info( - f"Successfully fetched latest trace for rollout '{rollout_id}', insertion_id (last5): {latest_insertion_id[-5:]}" + f"Successfully fetched latest trace for rollout '{rollout_id}', insertion_id: {latest_insertion_id}" ) return LangfuseTracesResponse( project_id=project_id, @@ -505,13 +507,13 @@ async def pointwise_fetch_langfuse_trace( if retry < max_retries - 1: wait_time = 2**retry logger.info( - f"Pointwise fetch attempt {retry + 1}/{max_retries} failed for rollout '{rollout_id}', insertion_id (last5): {latest_insertion_id[-5:]}. Retrying in {wait_time}s..." + f"Pointwise fetch attempt {retry + 1}/{max_retries} failed for rollout '{rollout_id}', insertion_id: {latest_insertion_id}. Retrying in {wait_time}s..." 
) await asyncio.sleep(wait_time) # After all retries failed logger.error( - f"Failed to fetch latest trace for rollout '{rollout_id}', insertion_id (last5): {latest_insertion_id[-5:]} after {max_retries} retries" + f"Failed to fetch latest trace for rollout '{rollout_id}', insertion_id: {latest_insertion_id} after {max_retries} retries" ) raise HTTPException( status_code=404, From e753a0be05950541cdf51cd66ad22190ec1b1571 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Fri, 10 Oct 2025 21:44:48 -0700 Subject: [PATCH 11/15] comments --- eval_protocol/proxy/__init__.py | 3 ++- eval_protocol/proxy/proxy_core/__init__.py | 5 +++-- eval_protocol/proxy/proxy_core/app.py | 8 ++++---- eval_protocol/proxy/proxy_core/auth.py | 7 +++---- eval_protocol/proxy/proxy_core/litellm.py | 1 - eval_protocol/proxy/proxy_core/models.py | 20 ++++++++++++++++---- 6 files changed, 28 insertions(+), 16 deletions(-) diff --git a/eval_protocol/proxy/__init__.py b/eval_protocol/proxy/__init__.py index 93bda257..c471064c 100644 --- a/eval_protocol/proxy/__init__.py +++ b/eval_protocol/proxy/__init__.py @@ -5,7 +5,7 @@ Langfuse tracing for distributed evaluation workflows. 
""" -from .proxy_core import create_app, AuthProvider, NoAuthProvider, ProxyConfig, ChatParams, TracesParams +from .proxy_core import create_app, AuthProvider, NoAuthProvider, ProxyConfig, ChatParams, TracesParams, AccountInfo __all__ = [ "create_app", @@ -14,4 +14,5 @@ "ProxyConfig", "ChatParams", "TracesParams", + "AccountInfo", ] diff --git a/eval_protocol/proxy/proxy_core/__init__.py b/eval_protocol/proxy/proxy_core/__init__.py index d221be71..053f922f 100644 --- a/eval_protocol/proxy/proxy_core/__init__.py +++ b/eval_protocol/proxy/proxy_core/__init__.py @@ -1,11 +1,12 @@ -from .models import ProxyConfig, ChatParams, TracesParams -from .app import create_app +from .models import ProxyConfig, ChatParams, TracesParams, AccountInfo from .auth import AuthProvider, NoAuthProvider +from .app import create_app __all__ = [ "ProxyConfig", "ChatParams", "TracesParams", + "AccountInfo", "create_app", "AuthProvider", "NoAuthProvider", diff --git a/eval_protocol/proxy/proxy_core/app.py b/eval_protocol/proxy/proxy_core/app.py index 77286109..528d467e 100644 --- a/eval_protocol/proxy/proxy_core/app.py +++ b/eval_protocol/proxy/proxy_core/app.py @@ -3,8 +3,8 @@ A FastAPI service that sits in front of LiteLLM and extracts metadata from URL paths. 
""" -from fastapi import FastAPI, Depends, HTTPException, Request, Query -from typing import Optional, Callable, Dict, Any, List +from fastapi import FastAPI, Depends, Request, Query +from typing import Optional, List import os import redis import logging @@ -166,8 +166,8 @@ def get_traces_params( ) async def require_auth(request: Request) -> None: - account_id = auth_provider.validate_and_return_account_id(request) - request.state.account_id = account_id + account_info = auth_provider.validate_and_return_account_info(request) + request.state.account_id = account_info.account_id if account_info else None return None # ===================== diff --git a/eval_protocol/proxy/proxy_core/auth.py b/eval_protocol/proxy/proxy_core/auth.py index fed24845..cdbf6d3c 100644 --- a/eval_protocol/proxy/proxy_core/auth.py +++ b/eval_protocol/proxy/proxy_core/auth.py @@ -1,18 +1,17 @@ from abc import ABC, abstractmethod import logging from fastapi import Request -from fastapi import HTTPException -import httpx from typing import Optional +from .models import AccountInfo logger = logging.getLogger(__name__) class AuthProvider(ABC): @abstractmethod - def validate_and_return_account_id(self, request: Request) -> Optional[str]: ... + def validate_and_return_account_info(self, request: Request) -> Optional[AccountInfo]: ... 
class NoAuthProvider(AuthProvider): - def validate_and_return_account_id(self, request: Request) -> Optional[str]: + def validate_and_return_account_info(self, request: Request) -> Optional[AccountInfo]: return None diff --git a/eval_protocol/proxy/proxy_core/litellm.py b/eval_protocol/proxy/proxy_core/litellm.py index bade23fb..29b29923 100644 --- a/eval_protocol/proxy/proxy_core/litellm.py +++ b/eval_protocol/proxy/proxy_core/litellm.py @@ -8,7 +8,6 @@ import logging from uuid6 import uuid7 from fastapi import Request, Response, HTTPException -from typing import Optional import redis from .redis_utils import register_insertion_id from .models import ProxyConfig, ChatParams diff --git a/eval_protocol/proxy/proxy_core/models.py b/eval_protocol/proxy/proxy_core/models.py index 9bbc92fe..bee7ceeb 100644 --- a/eval_protocol/proxy/proxy_core/models.py +++ b/eval_protocol/proxy/proxy_core/models.py @@ -3,12 +3,24 @@ """ from pydantic import BaseModel -from typing import Optional, List, Any, Dict, Callable, TypeAlias -from fastapi import Request, Query +from typing import Optional, List, Any, Dict, Protocol +from fastapi import Request -ChatRequestHook: TypeAlias = Callable[[Dict[str, Any], Request, "ChatParams"], tuple[Dict[str, Any], "ChatParams"]] -TracesRequestHook: TypeAlias = Callable[[Request, "TracesParams"], "TracesParams"] +class AccountInfo(BaseModel): + """Account information returned from authentication.""" + + account_id: str + + +class ChatRequestHook(Protocol): + def __call__( + self, data: Dict[str, Any], request: Request, params: "ChatParams" + ) -> tuple[Dict[str, Any], "ChatParams"]: ... + + +class TracesRequestHook(Protocol): + def __call__(self, request: Request, params: "TracesParams") -> "TracesParams": ... 
class ChatParams(BaseModel): From bfe8e3146c3971cadf5c7e43d259b40e7e26163a Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 13 Oct 2025 17:28:46 -0700 Subject: [PATCH 12/15] back to typealias --- eval_protocol/proxy/proxy_core/models.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/eval_protocol/proxy/proxy_core/models.py b/eval_protocol/proxy/proxy_core/models.py index bee7ceeb..f3b5e614 100644 --- a/eval_protocol/proxy/proxy_core/models.py +++ b/eval_protocol/proxy/proxy_core/models.py @@ -3,26 +3,20 @@ """ from pydantic import BaseModel -from typing import Optional, List, Any, Dict, Protocol +from typing import Optional, List, Any, Dict, TypeAlias, Callable from fastapi import Request +ChatRequestHook: TypeAlias = Callable[[Dict[str, Any], Request, "ChatParams"], tuple[Dict[str, Any], "ChatParams"]] +TracesRequestHook: TypeAlias = Callable[[Request, "TracesParams"], "TracesParams"] + + class AccountInfo(BaseModel): """Account information returned from authentication.""" account_id: str -class ChatRequestHook(Protocol): - def __call__( - self, data: Dict[str, Any], request: Request, params: "ChatParams" - ) -> tuple[Dict[str, Any], "ChatParams"]: ... - - -class TracesRequestHook(Protocol): - def __call__(self, request: Request, params: "TracesParams") -> "TracesParams": ... 
- - class ChatParams(BaseModel): """Typed container for chat completion URL path parameters.""" From ed534b433d8bd61d4f8b7157ae893360d9f7923d Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 13 Oct 2025 17:33:14 -0700 Subject: [PATCH 13/15] export accountinfo --- eval_protocol/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/eval_protocol/__init__.py b/eval_protocol/__init__.py index 9cb800e0..2906337d 100644 --- a/eval_protocol/__init__.py +++ b/eval_protocol/__init__.py @@ -71,10 +71,11 @@ WeaveAdapter = None try: - from .proxy import create_app, AuthProvider + from .proxy import create_app, AuthProvider, AccountInfo except ImportError: create_app = None AuthProvider = None + AccountInfo = None warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol") @@ -140,6 +141,7 @@ # Proxy "create_app", "AuthProvider", + "AccountInfo", ] from . import _version From ea7687dddf1b730b8aa84b23d898da280e22fc7f Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 13 Oct 2025 18:00:38 -0700 Subject: [PATCH 14/15] add .env instructions --- eval_protocol/proxy/.env.example | 2 ++ eval_protocol/proxy/docker-compose.yml | 3 +++ 2 files changed, 5 insertions(+) create mode 100644 eval_protocol/proxy/.env.example diff --git a/eval_protocol/proxy/.env.example b/eval_protocol/proxy/.env.example new file mode 100644 index 00000000..1a6eb490 --- /dev/null +++ b/eval_protocol/proxy/.env.example @@ -0,0 +1,2 @@ +# In order to set other model providers keys for proxy, make a copy, rename to .env, and fill here +OPENAI_API_KEY=sk-proj-xxx diff --git a/eval_protocol/proxy/docker-compose.yml b/eval_protocol/proxy/docker-compose.yml index 4128e654..a6058e0e 100644 --- a/eval_protocol/proxy/docker-compose.yml +++ b/eval_protocol/proxy/docker-compose.yml @@ -19,6 +19,9 @@ services: platform: linux/amd64 container_name: litellm-backend command: ["--config", "/app/config.yaml", "--port", "4000", "--host", "0.0.0.0"] + # If you want to be 
able to use other model providers like OpenAI, Anthropic, etc., you need to set keys in .env file. + env_file: + - .env # Load API keys from .env file environment: - LANGFUSE_PUBLIC_KEY=dummy # Set dummy public and private key so Langfuse instance initializes in LiteLLM, then real keys get sent in proxy - LANGFUSE_SECRET_KEY=dummy From 3c516e0d466d1a1a2d501f7ca0ac6ee7f10cf017 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 13 Oct 2025 20:26:15 -0700 Subject: [PATCH 15/15] only pass fireworks key if it's fireworks model --- eval_protocol/proxy/proxy_core/litellm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/eval_protocol/proxy/proxy_core/litellm.py b/eval_protocol/proxy/proxy_core/litellm.py index 29b29923..cdd2383b 100644 --- a/eval_protocol/proxy/proxy_core/litellm.py +++ b/eval_protocol/proxy/proxy_core/litellm.py @@ -63,7 +63,10 @@ async def handle_chat_completion( auth_header = request.headers.get("authorization", "") if auth_header.startswith("Bearer "): api_key = auth_header.replace("Bearer ", "").strip() - data["api_key"] = api_key + # Only inject API key if model is a Fireworks model + model = data.get("model") + if model and isinstance(model, str) and model.startswith("fireworks_ai"): + data["api_key"] = api_key # If metadata IDs are provided, add them as tags insertion_id = None