From 565758653b458eef283246c45db30af9ed7e1a98 Mon Sep 17 00:00:00 2001 From: benjibc Date: Fri, 19 Sep 2025 18:55:57 +0000 Subject: [PATCH 1/2] HTTP remote rollout server support --- eval_protocol/adapters/langfuse.py | 172 +++++++++++++++-- eval_protocol/adapters/langsmith.py | 6 +- eval_protocol/pytest/__init__.py | 2 + .../pytest/remote_rollout_processor.py | 147 +++++++++++++++ tests/chinook/langfuse/remote_server.py | 177 ++++++++++++++++++ .../langfuse/test_remote_langfuse_chinook.py | 142 ++++++++++++++ 6 files changed, 627 insertions(+), 19 deletions(-) create mode 100644 eval_protocol/pytest/remote_rollout_processor.py create mode 100644 tests/chinook/langfuse/remote_server.py create mode 100644 tests/chinook/langfuse/test_remote_langfuse_chinook.py diff --git a/eval_protocol/adapters/langfuse.py b/eval_protocol/adapters/langfuse.py index 44c43fe2..5825445f 100644 --- a/eval_protocol/adapters/langfuse.py +++ b/eval_protocol/adapters/langfuse.py @@ -4,12 +4,13 @@ to EvaluationRow format for use in evaluation pipelines. """ -from langfuse.api.resources.commons.types.observations_view import ObservationsView +from __future__ import annotations + import logging import random import time from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional, Protocol +from typing import Any, Dict, List, Optional, Protocol, TYPE_CHECKING from eval_protocol.models import EvaluationRow, InputMetadata, Message from .base import BaseAdapter @@ -46,14 +47,15 @@ def __call__( try: from langfuse import get_client # pyright: ignore[reportPrivateImportUsage] - from langfuse.api.resources.trace.types.traces import Traces - from langfuse.api.resources.commons.types.trace import Trace - from langfuse.api.resources.commons.types.trace_with_full_details import TraceWithFullDetails LANGFUSE_AVAILABLE = True except ImportError: LANGFUSE_AVAILABLE = False +if TYPE_CHECKING: + from langfuse.api.resources.commons.types.trace_with_full_details import TraceWithFullDetails + from langfuse.api.resources.commons.types.observations_view import ObservationsView + def convert_trace_to_evaluation_row( trace: "TraceWithFullDetails", include_tool_calls: bool = True, span_name: Optional[str] = None @@ -64,7 +66,6 @@ def convert_trace_to_evaluation_row( trace: Langfuse trace object include_tool_calls: Whether to include tool calling information span_name: If provided, extract messages from generations within this named span - converter: Optional custom converter implementing TraceConverter protocol Returns: EvaluationRow or None if conversion fails @@ -97,7 +98,7 @@ def convert_trace_to_evaluation_row( def extract_messages_from_trace( - trace: TraceWithFullDetails, include_tool_calls: bool = True, span_name: Optional[str] = None + trace: "TraceWithFullDetails", include_tool_calls: bool = True, span_name: Optional[str] = None ) -> List[Message]: """Extract messages from Langfuse trace input and output. @@ -114,7 +115,7 @@ def extract_messages_from_trace( if span_name: # Look for a generation tied to a span name try: # Find the final generation in the named span - gen: ObservationsView | None = get_final_generation_in_span(trace, span_name) + gen: "ObservationsView | None" = get_final_generation_in_span(trace, span_name) if not gen: return messages @@ -140,10 +141,27 @@ def extract_messages_from_trace( except (AttributeError, ValueError, KeyError) as e: logger.warning("Error processing trace %s: %s", trace.id, e) + # Fallback: use the last GENERATION observation which typically contains full chat history + if not messages: + try: + all_observations = getattr(trace, "observations", None) or [] + gens: List[ObservationsView] = [ + obs for obs in all_observations if getattr(obs, "type", None) == "GENERATION" + ] + if gens: + gens.sort(key=lambda x: x.start_time) + last_gen = gens[-1] + if getattr(last_gen, "input", None): + messages.extend(extract_messages_from_data(getattr(last_gen, "input"), include_tool_calls)) + if getattr(last_gen, "output", None): + messages.extend(extract_messages_from_data(getattr(last_gen, "output"), include_tool_calls)) + except Exception as e: + logger.warning("Failed to extract from last generation for trace %s: %s", trace.id, e) + return messages -def get_final_generation_in_span(trace: TraceWithFullDetails, span_name: str) -> ObservationsView | None: +def get_final_generation_in_span(trace: "TraceWithFullDetails", span_name: str) -> "ObservationsView | None": """Get the final generation within a named span that contains full message history. Args: @@ -173,7 +191,7 @@ def get_final_generation_in_span(trace: TraceWithFullDetails, span_name: str) -> return None # Find all generations within this span - generations: List[ObservationsView] = [] + generations: List["ObservationsView"] = [] for obs in all_observations: if obs.type == "GENERATION" and obs.parent_observation_id == parent_span.id: generations.append(obs) @@ -241,6 +259,9 @@ def get_evaluation_rows( max_retries: int = 3, span_name: Optional[str] = None, converter: Optional[TraceConverter] = None, + metadata: Optional[Dict[str, Any]] = None, + requester_metadata: Optional[Dict[str, Any]] = None, + requester_metadata_contains: Optional[str] = None, ) -> List[EvaluationRow]: """Pull traces from Langfuse and convert to EvaluationRow format. @@ -275,6 +296,10 @@ def get_evaluation_rows( to_timestamp = datetime.now() from_timestamp = to_timestamp - timedelta(hours=hours_back) + # If filtering by metadata/requester_metadata, prefer fetching metadata fields + if (metadata is not None or requester_metadata is not None or requester_metadata_contains) and not fields: + fields = "core,metadata,observations" + # Collect trace summaries via pagination (up to limit) all_traces = [] page = 1 @@ -354,6 +379,74 @@ def get_evaluation_rows( selected_traces = all_traces logger.debug("Processing all %d collected traces (no sampling)", len(all_traces)) + # Helper to check if a trace matches provided metadata filters. We look in multiple places + # to account for Langfuse moving fields (e.g., metadata vs requester_metadata) and SDK shape. + def _trace_matches_metadata_filters(trace_obj: Any) -> bool: + if metadata is None and requester_metadata is None: + return True + + def _as_dict(val: Any) -> Dict[str, Any]: + if val is None: + return {} + if isinstance(val, dict): + return val + # Some SDK objects expose .model_dump() or behave like pydantic models + dump = getattr(val, "model_dump", None) + if callable(dump): + try: + return dump() # type: ignore[no-any-return] + except Exception: + return {} + return {} + + # Try common locations for metadata on full trace + trace_meta = _as_dict(getattr(trace_obj, "metadata", None)) + trace_req_meta = _as_dict(getattr(trace_obj, "requester_metadata", None)) + # Some Langfuse deployments nest requester_metadata inside metadata + nested_req_meta = {} + try: + if isinstance(trace_meta, dict) and isinstance(trace_meta.get("requester_metadata"), dict): + nested_req_meta = _as_dict(trace_meta.get("requester_metadata")) + except Exception: + nested_req_meta = {} + + # Fallbacks: sometimes metadata is embedded in input + input_meta = {} + try: + inp = getattr(trace_obj, "input", None) + if isinstance(inp, dict): + input_meta = _as_dict(inp.get("metadata")) + except Exception: + input_meta = {} + + # Combine for matching convenience (later keys override earlier for equality check only) + combined_meta = {**trace_meta, **input_meta} + combined_req_meta = {**trace_req_meta} + + # Also merge nested requester metadata when present + if nested_req_meta: + combined_req_meta = {**combined_req_meta, **nested_req_meta} + + def _is_subset(needle: Dict[str, Any], haystack: Dict[str, Any]) -> bool: + for k, v in needle.items(): + if haystack.get(k) != v: + return False + return True + + ok_meta = True + ok_req_meta = True + + if metadata is not None: + # Accept match if found either in metadata or requester_metadata buckets + ok_meta = _is_subset(metadata, combined_meta) or _is_subset(metadata, combined_req_meta) + + if requester_metadata is not None: + ok_req_meta = _is_subset(requester_metadata, combined_req_meta) or _is_subset( + requester_metadata, combined_meta + ) + + return ok_meta and ok_req_meta + # Process each selected trace with sleep and retry logic for trace_info in selected_traces: # Sleep between gets to avoid rate limits @@ -365,6 +458,7 @@ def get_evaluation_rows( detail_retries = 0 while detail_retries < max_retries: try: + # Some SDKs don't support fields= on get; call without it trace_full = self.client.api.trace.get(trace_info.id) break except Exception as e: @@ -379,11 +473,49 @@ def get_evaluation_rows( max_retries, ) time.sleep(sleep_time) + elif "Not Found" in str(e) or "404" in str(e): + # Skip missing traces quickly + logger.debug("Trace %s not found, skipping", trace_info.id) + trace_full = None + break else: logger.warning("Failed to fetch trace %s after %d retries: %s", trace_info.id, max_retries, e) break # Skip this trace if trace_full: + # If metadata filters are provided, skip non-matching traces early + try: + if not _trace_matches_metadata_filters(trace_full): + continue + except Exception: + # Be permissive on filter errors; treat as non-match + continue + + # If observations carry requester_metadata, allow substring filtering + if requester_metadata_contains: + contains_val = requester_metadata_contains + found_match = False + try: + for obs in getattr(trace_full, "observations", []) or []: + obs_rmd = getattr(obs, "requester_metadata", None) + if isinstance(obs_rmd, dict) and any( + (isinstance(v, str) and contains_val in v) for v in obs_rmd.values() + ): + found_match = True + break + obs_md = getattr(obs, "metadata", None) + if isinstance(obs_md, dict): + nested = obs_md.get("requester_metadata") + if isinstance(nested, dict) and any( + (isinstance(v, str) and contains_val in v) for v in nested.values() + ): + found_match = True + break + except Exception: + found_match = False + if not found_match: + continue + try: if converter: eval_row = converter(trace_full, include_tool_calls, span_name) @@ -451,16 +583,22 @@ def upload_scores(self, rows: List[EvaluationRow], model_name: str, mean_score: """ try: for trace_id in set( - row.input_metadata.session_data["langfuse_trace_id"] + (row.input_metadata.session_data or {}).get("langfuse_trace_id") for row in rows - if row.evaluation_result and row.input_metadata and row.input_metadata.session_data + if row.input_metadata and row.input_metadata.session_data ): if trace_id: - self.client.create_score( - trace_id=trace_id, - name=model_name, - value=mean_score, - ) + try: + self.client.api.score.create( + trace_id=trace_id, + name=model_name, + value=mean_score, + ) + except Exception: + # Fallback to legacy client if available in some environments + create_score = getattr(self.client, "create_score", None) + if callable(create_score): + create_score(trace_id=trace_id, name=model_name, value=mean_score) except Exception as e: logger.warning("Failed to push scores to Langfuse: %s", e) diff --git a/eval_protocol/adapters/langsmith.py b/eval_protocol/adapters/langsmith.py index fc1daf71..79b23ea7 100644 --- a/eval_protocol/adapters/langsmith.py +++ b/eval_protocol/adapters/langsmith.py @@ -35,10 +35,12 @@ class LangSmithAdapter(BaseAdapter): - outputs: { messages: [...] } | { content } | { result } | { answer } | { output } | str | list[dict] """ - def __init__(self, client: Optional[Client] = None) -> None: + def __init__(self, client: Optional[Any] = None) -> None: if not LANGSMITH_AVAILABLE: raise ImportError("LangSmith not installed. Install with: pip install 'eval-protocol[langsmith]'") - self.client = client or Client() + # Client is provided by langsmith package; typing is relaxed to Any to avoid + # static analysis issues when stubs aren't available. + self.client = client or Client() # type: ignore[reportCallIssue] def get_evaluation_rows( self, diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py index b6d02ae2..35832b6f 100644 --- a/eval_protocol/pytest/__init__.py +++ b/eval_protocol/pytest/__init__.py @@ -3,6 +3,7 @@ from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor from .default_no_op_rollout_processor import NoOpRolloutProcessor from .default_single_turn_rollout_process import SingleTurnRolloutProcessor +from .remote_rollout_processor import RemoteRolloutProcessor from .evaluation_test import evaluation_test from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config from .rollout_processor import RolloutProcessor @@ -31,6 +32,7 @@ "MCPGymRolloutProcessor", "RolloutProcessor", "SingleTurnRolloutProcessor", + "RemoteRolloutProcessor", "NoOpRolloutProcessor", "default_dataset_adapter", "RolloutProcessorConfig", diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py new file mode 100644 index 00000000..848d9108 --- /dev/null +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -0,0 +1,147 @@ +import asyncio +import time +from typing import Any, Dict, List, Optional + +import requests + +from eval_protocol.models import EvaluationRow +from .rollout_processor import RolloutProcessor +from .types import RolloutProcessorConfig + + +class RemoteRolloutProcessor(RolloutProcessor): + """ + Rollout processor that triggers a remote HTTP server to perform the rollout. + + Expected remote API: + - POST {remote_base_url}/init + Body: { + "rollout_id": str, + "model": str, + "messages": list[dict], + "tools": list[dict] | null, + "metadata": { + "invocation_id": str, + "experiment_id": str, + "rollout_id": str, + "run_id": str | null, + "row_id": str | null + }, + "num_turns": int + } + Returns: {"ok": true} + + - GET {remote_base_url}/status?rollout_id=... + Returns: {"terminated": bool, "info": {...}?} + """ + + def __init__(self): + pass + + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + tasks: List[asyncio.Task[EvaluationRow]] = [] + + remote_base_url: Optional[str] = None + num_turns: int = 2 + poll_interval: float = 1.0 + timeout_seconds: float = 120.0 + + if config.kwargs: + remote_base_url = config.kwargs.get("remote_base_url") + num_turns = int(config.kwargs.get("num_turns", num_turns)) + poll_interval = float(config.kwargs.get("poll_interval", poll_interval)) + timeout_seconds = float(config.kwargs.get("timeout_seconds", timeout_seconds)) + + if not remote_base_url: + raise ValueError("remote_base_url is required in RolloutProcessorConfig.kwargs for RemoteRolloutProcessor") + + async def _process_row(row: EvaluationRow) -> EvaluationRow: + start_time = time.perf_counter() + + # Build request metadata and payload + meta: Dict[str, Any] = { + "invocation_id": row.execution_metadata.invocation_id, + "experiment_id": row.execution_metadata.experiment_id, + "rollout_id": row.execution_metadata.rollout_id, + "run_id": row.execution_metadata.run_id, + "row_id": row.input_metadata.row_id, + } + + model: Optional[str] = None + if row.input_metadata and row.input_metadata.completion_params: + model = row.input_metadata.completion_params.get("model") + if model is None and config.completion_params: + model = config.completion_params.get("model") + if model is None: + raise ValueError( + "Model must be provided in row.input_metadata.completion_params or config.completion_params" + ) + + # Strip non-OpenAI fields from messages before sending to remote + allowed_message_fields = {"role", "content", "tool_calls", "tool_call_id", "name"} + clean_messages = [] + for m in row.messages: + md: Dict[str, Any] + if hasattr(m, "model_dump"): + md = m.model_dump() # type: ignore[assignment] + elif isinstance(m, dict): + md = m # type: ignore[assignment] + else: + # Fallback to constructing a dict from Message-like object + md = { + "role": getattr(m, "role", None), + "content": getattr(m, "content", None), + "tool_calls": getattr(m, "tool_calls", None), + "tool_call_id": getattr(m, "tool_call_id", None), + "name": getattr(m, "name", None), + } + clean_messages.append({k: v for k, v in md.items() if k in allowed_message_fields and v is not None}) + + init_payload: Dict[str, Any] = { + "rollout_id": row.execution_metadata.rollout_id, + "model": model, + "messages": clean_messages, + "tools": row.tools, + "metadata": meta, + "num_turns": num_turns, + } + + # Fire-and-poll + def _post_init() -> None: + url = f"{remote_base_url}/init" + r = requests.post(url, json=init_payload, timeout=30) + r.raise_for_status() + + await asyncio.to_thread(_post_init) + + terminated = False + deadline = time.time() + timeout_seconds + + def _get_status() -> Dict[str, Any]: + url = f"{remote_base_url}/status" + r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15) + r.raise_for_status() + return r.json() + + while time.time() < deadline: + try: + status = await asyncio.to_thread(_get_status) + terminated = bool(status.get("terminated", False)) + if terminated: + break + except Exception: + # transient errors; continue polling + pass + await asyncio.sleep(poll_interval) + + # Update duration, regardless of termination + row.execution_metadata.duration_seconds = time.perf_counter() - start_time + return row + + for r in rows: + tasks.append(asyncio.create_task(_process_row(r))) + + return tasks + + def cleanup(self) -> None: + return None diff --git a/tests/chinook/langfuse/remote_server.py b/tests/chinook/langfuse/remote_server.py new file mode 100644 index 00000000..36cdae20 --- /dev/null +++ b/tests/chinook/langfuse/remote_server.py @@ -0,0 +1,177 @@ +import os +import threading +from typing import Any, Dict + +import uvicorn +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +import requests + + +app = FastAPI() + + +class InitRequest(BaseModel): + rollout_id: str + model: str + messages: list[dict] + tools: list[dict] | None = None + metadata: dict + num_turns: int = 2 + + +_STATE: Dict[str, Dict[str, Any]] = {} + + +ALLOWED_MESSAGE_FIELDS = {"role", "content", "tool_calls", "tool_call_id", "name"} + + +def _clean_messages_for_api(messages: list[dict]) -> list[dict]: + cleaned: list[dict] = [] + for msg in messages: + if not isinstance(msg, dict): + continue + cm = {k: v for k, v in msg.items() if k in ALLOWED_MESSAGE_FIELDS and v is not None} + # Some providers dislike empty content on assistant messages; keep if present + cleaned.append(cm) + return cleaned + + +@app.post("/init") +def init(req: InitRequest): + # Persist state + _STATE[req.rollout_id] = {"terminated": False} + + # Kick off worker thread that runs multi-turn chat via LiteLLM proxy + def _worker(): + try: + # Try to set up Langfuse trace to guarantee observability, independent of proxy wiring + langfuse = None + trace = None + try: + from langfuse import get_client # pyright: ignore[reportPrivateImportUsage] + + langfuse = get_client() + id_tags = [] + try: + id_tags = [ + f"inv:{req.metadata.get('invocation_id')}", + f"exp:{req.metadata.get('experiment_id')}", + f"rollout:{req.metadata.get('rollout_id')}", + ] + except Exception: + id_tags = [] + trace = langfuse.api.trace.create( + name="remote_chinook_rollout", + metadata=req.metadata, + requester_metadata=req.metadata, + tags=["chinook_remote", "chinook_sql", *[t for t in id_tags if t]], + input={ + "messages": _clean_messages_for_api(req.messages), + "tools": req.tools, + "metadata": req.metadata, + }, + ) + except Exception: + langfuse = None + trace = None + + base_url = os.getenv( + "LITELLM_BASE_URL", + "https://litellm-cloud-proxy-prod-644257448872.us-central1.run.app", + ) + url = f"{base_url}/v1/chat/completions" + headers = { + "Authorization": f"Bearer {os.getenv('FIREWORKS_API_KEY', '')}", + "Content-Type": "application/json", + } + + # Prepare metadata payload to attach for Langfuse filtering + metadata = { + "invocation_id": req.metadata.get("invocation_id"), + "experiment_id": req.metadata.get("experiment_id"), + "rollout_id": req.metadata.get("rollout_id"), + "run_id": req.metadata.get("run_id"), + "row_id": req.metadata.get("row_id"), + } + + messages = req.messages + + # Simulate N-1 assistant turns (single-shot or simple echo) + for _ in range(max(1, req.num_turns)): + payload = { + "model": req.model, + "messages": _clean_messages_for_api(messages), + "metadata": metadata, + } + if req.tools: + payload["tools"] = req.tools + r = requests.post(url, json=payload, headers=headers, timeout=60) + r.raise_for_status() + data = r.json() + assistant = data.get("choices", [{}])[0].get("message", {}) + # Optionally record a generation on Langfuse + try: + if langfuse and trace and getattr(langfuse.api, "generation", None): + langfuse.api.generation.create( + trace_id=trace.id, + name="assistant", + input={"messages": _clean_messages_for_api(messages)}, + output=assistant, + ) + except Exception: + pass + # Append assistant for next turn + messages = messages + [assistant] + + # Update final trace output for easier adapter extraction + try: + if langfuse and trace: + langfuse.api.trace.update( + id=trace.id, + output={ + "messages": _clean_messages_for_api(messages), + "metadata": req.metadata, + }, + ) + except Exception: + pass + + except Exception: + # Best-effort; mark as done even on error to unblock polling + pass + finally: + try: + if "langfuse" in locals() and langfuse is not None: + # Ensure buffered telemetry is sent + flush = getattr(langfuse, "flush", None) + if callable(flush): + flush() + shutdown = getattr(langfuse, "shutdown", None) + if callable(shutdown): + shutdown() + except Exception: + pass + _STATE[req.rollout_id]["terminated"] = True + + t = threading.Thread(target=_worker, daemon=True) + t.start() + return {"ok": True} + + +@app.get("/status") +def status(rollout_id: str): + st = _STATE.get(rollout_id) + if not st: + raise HTTPException(status_code=404, detail="unknown rollout_id") + return {"terminated": bool(st.get("terminated", False))} + + +def main(): + host = os.getenv("REMOTE_SERVER_HOST", "127.0.0.1") + port = int(os.getenv("REMOTE_SERVER_PORT", "7077")) + uvicorn.run(app, host=host, port=port) + + +if __name__ == "__main__": + main() diff --git a/tests/chinook/langfuse/test_remote_langfuse_chinook.py b/tests/chinook/langfuse/test_remote_langfuse_chinook.py new file mode 100644 index 00000000..daf20715 --- /dev/null +++ b/tests/chinook/langfuse/test_remote_langfuse_chinook.py @@ -0,0 +1,142 @@ +import os +import multiprocessing +import time +from datetime import datetime, timedelta +from typing import List +import atexit + +import pytest + +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor + + +def _start_remote_server(): + # Starts FastAPI server defined in remote_server.py using absolute import + import importlib + + os.environ.setdefault("REMOTE_SERVER_HOST", "127.0.0.1") + os.environ.setdefault("REMOTE_SERVER_PORT", "7077") + mod = importlib.import_module("tests.chinook.langfuse.remote_server") + mod.main() + + +def _ensure_server_running(): + # Launch in a background process + proc = multiprocessing.Process(target=_start_remote_server, daemon=True) + proc.start() + # Give it a moment to boot + time.sleep(1.5) + return proc + + +# Ensure server is running BEFORE rollouts start (evaluation_test triggers rollouts before test body) +_SERVER_PROC = _ensure_server_running() +atexit.register(lambda: (_SERVER_PROC.terminate() if _SERVER_PROC.is_alive() else None)) + + +def _make_input_rows() -> List[EvaluationRow]: + # Minimal single-user-turn message to trigger a response + row = EvaluationRow(messages=[Message(role="user", content="Hello there! Please say hi back.")]) + return [row] + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +@pytest.mark.asyncio +@evaluation_test( + input_rows=[_make_input_rows()], + completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}], + rollout_processor=RemoteRolloutProcessor(), + rollout_processor_kwargs={ + "remote_base_url": "http://127.0.0.1:7077", + "num_turns": 2, + "timeout_seconds": 30, + }, + mode="pointwise", +) +async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow: + """ + End-to-end test: + - remote server started at import time + - trigger remote rollout via RemoteRolloutProcessor (calls init/status) + - fetch traces from Langfuse filtered by metadata; FAIL if none found + """ + # Debug print IDs used for filtering + print( + "[Remote-E2E] IDs:", + { + "invocation_id": row.execution_metadata.invocation_id, + "experiment_id": row.execution_metadata.experiment_id, + "rollout_id": row.execution_metadata.rollout_id, + "run_id": row.execution_metadata.run_id, + }, + ) + + # Attempt retrieval via adapter + try: + from eval_protocol.adapters.langfuse import create_langfuse_adapter + + adapter = create_langfuse_adapter() + + # Preferred: observations-level requester_metadata contains invocation_id (proxy annotates per-request) + contains_val = row.execution_metadata.invocation_id or "" + rows = [] + if contains_val: + # Retry loop to allow ingestion/flush + deadline = time.time() + 90 + while time.time() < deadline and not rows: + rows = adapter.get_evaluation_rows( + limit=10, + from_timestamp=datetime.now() - timedelta(hours=2), + to_timestamp=datetime.now(), + include_tool_calls=False, + requester_metadata_contains=contains_val, + ) + if rows: + break + time.sleep(3) + else: + print("[Remote-E2E] Missing invocation_id; skipping observations filter") + + # If still empty, dump recent trace metadata for debugging + if not rows: + try: + from langfuse import get_client # pyright: ignore[reportPrivateImportUsage] + + lf = get_client() + recent = lf.api.trace.list(limit=5, order_by="timestamp.desc") + print("[Remote-E2E] Recent trace metadata dump (id, metadata, requester_metadata, tags):") + if recent and getattr(recent, "data", None): + for t in recent.data: + try: + full = lf.api.trace.get(t.id) + print( + { + "id": full.id, + "metadata": getattr(full, "metadata", None), + "requester_metadata": getattr(full, "requester_metadata", None), + "tags": getattr(full, "tags", None), + } + ) + except Exception as e: + print("[Remote-E2E] Failed to get trace details:", e) + else: + print("[Remote-E2E] No recent traces found via list().") + except Exception as e: + print("[Remote-E2E] Langfuse debug fetch failed:", e) + + assert rows and len(rows) > 0, ( + "No Langfuse traces matched the metadata. Ensure the LiteLLM proxy is configured to forward " + "Langfuse telemetry and that LANGFUSE_* env vars are set." + ) + + # Minimal sanity: rows contain session_data.langfuse_trace_id + assert any((r.input_metadata.session_data or {}).get("langfuse_trace_id") for r in rows), ( + "Expected langfuse_trace_id in session_data for at least one row" + ) + + except ImportError: + pytest.fail("Langfuse SDK not installed; cannot verify traces.") + + return row From 9be55309a87659bd6e9146b846fcbe63fa8f132f Mon Sep 17 00:00:00 2001 From: benjibc Date: Sat, 20 Sep 2025 16:36:10 +0000 Subject: [PATCH 2/2] clean up --- .../pytest/remote_rollout_processor.py | 29 ++++++-- tests/chinook/langfuse/remote_server.py | 66 ------------------- .../langfuse/test_remote_langfuse_chinook.py | 37 ++++++++--- 3 files changed, 50 insertions(+), 82 deletions(-) diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 848d9108..04963f0d 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -35,19 +35,34 @@ class RemoteRolloutProcessor(RolloutProcessor): Returns: {"terminated": bool, "info": {...}?} """ - def __init__(self): - pass + def __init__( + self, + *, + remote_base_url: Optional[str] = None, + num_turns: int = 2, + poll_interval: float = 1.0, + timeout_seconds: float = 120.0, + ): + # Prefer constructor-provided configuration. These can be overridden via + # config.kwargs at call time for backward compatibility. + self._remote_base_url = remote_base_url + self._num_turns = num_turns + self._poll_interval = poll_interval + self._timeout_seconds = timeout_seconds def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: tasks: List[asyncio.Task[EvaluationRow]] = [] - remote_base_url: Optional[str] = None - num_turns: int = 2 - poll_interval: float = 1.0 - timeout_seconds: float = 120.0 + # Start with constructor values + remote_base_url: Optional[str] = self._remote_base_url + num_turns: int = self._num_turns + poll_interval: float = self._poll_interval + timeout_seconds: float = self._timeout_seconds + # Backward compatibility: allow overrides via config.kwargs if config.kwargs: - remote_base_url = config.kwargs.get("remote_base_url") + if remote_base_url is None: + remote_base_url = config.kwargs.get("remote_base_url", remote_base_url) num_turns = int(config.kwargs.get("num_turns", num_turns)) poll_interval = float(config.kwargs.get("poll_interval", poll_interval)) timeout_seconds = float(config.kwargs.get("timeout_seconds", timeout_seconds)) diff --git a/tests/chinook/langfuse/remote_server.py b/tests/chinook/langfuse/remote_server.py index 36cdae20..71971378 100644 --- a/tests/chinook/langfuse/remote_server.py +++ b/tests/chinook/langfuse/remote_server.py @@ -45,37 +45,6 @@ def init(req: InitRequest): # Kick off worker thread that runs multi-turn chat via LiteLLM proxy def _worker(): try: - # Try to set up Langfuse trace to guarantee observability, independent of proxy wiring - langfuse = None - trace = None - try: - from langfuse import get_client # pyright: ignore[reportPrivateImportUsage] - - langfuse = get_client() - id_tags = [] - try: - id_tags = [ - f"inv:{req.metadata.get('invocation_id')}", - f"exp:{req.metadata.get('experiment_id')}", - f"rollout:{req.metadata.get('rollout_id')}", - ] - except Exception: - id_tags = [] - trace = langfuse.api.trace.create( - name="remote_chinook_rollout", - metadata=req.metadata, - requester_metadata=req.metadata, - tags=["chinook_remote", "chinook_sql", *[t for t in id_tags if t]], - input={ - "messages": _clean_messages_for_api(req.messages), - "tools": req.tools, - "metadata": req.metadata, - }, - ) - except Exception: - langfuse = None - trace = None - base_url = os.getenv( "LITELLM_BASE_URL", "https://litellm-cloud-proxy-prod-644257448872.us-central1.run.app", @@ -110,48 +79,13 @@ def _worker(): r.raise_for_status() data = r.json() assistant = data.get("choices", [{}])[0].get("message", {}) - # Optionally record a generation on Langfuse - try: - if langfuse and trace and getattr(langfuse.api, "generation", None): - langfuse.api.generation.create( - trace_id=trace.id, - name="assistant", - input={"messages": _clean_messages_for_api(messages)}, - output=assistant, - ) - except Exception: - pass # Append assistant for next turn messages = messages + [assistant] - # Update final trace output for easier adapter extraction - try: - if langfuse and trace: - langfuse.api.trace.update( - id=trace.id, - output={ - "messages": _clean_messages_for_api(messages), - "metadata": req.metadata, - }, - ) - except Exception: - pass - except Exception: # Best-effort; mark as done even on error to unblock polling pass finally: - try: - if "langfuse" in locals() and langfuse is not None: - # Ensure buffered telemetry is sent - flush = getattr(langfuse, "flush", None) - if callable(flush): - flush() - shutdown = getattr(langfuse, "shutdown", None) - if callable(shutdown): - shutdown() - except Exception: - pass _STATE[req.rollout_id]["terminated"] = True t = threading.Thread(target=_worker, daemon=True) diff --git a/tests/chinook/langfuse/test_remote_langfuse_chinook.py b/tests/chinook/langfuse/test_remote_langfuse_chinook.py index daf20715..1b3951c9 100644 --- a/tests/chinook/langfuse/test_remote_langfuse_chinook.py +++ b/tests/chinook/langfuse/test_remote_langfuse_chinook.py @@ -6,6 +6,7 @@ import atexit import pytest +import requests from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest import evaluation_test @@ -23,17 +24,36 @@ def _start_remote_server(): def _ensure_server_running(): + host = os.getenv("REMOTE_SERVER_HOST", "127.0.0.1") + port = int(os.getenv("REMOTE_SERVER_PORT", "7077")) + base_url = f"http://{host}:{port}" + + def _is_up() -> bool: + try: + r = requests.get(f"{base_url}/status", params={"rollout_id": "ping"}, timeout=1.0) + return r.status_code in (200, 404) + except Exception: + return False + + if _is_up(): + return None + # Launch in a background process proc = multiprocessing.Process(target=_start_remote_server, daemon=True) proc.start() - # Give it a moment to boot - time.sleep(1.5) + + # Poll for readiness up to 10s + deadline = time.time() + 10 + while time.time() < deadline: + if _is_up(): + break + time.sleep(0.5) return proc # Ensure server is running BEFORE rollouts start (evaluation_test triggers rollouts before test body) _SERVER_PROC = _ensure_server_running() -atexit.register(lambda: (_SERVER_PROC.terminate() if _SERVER_PROC.is_alive() else None)) +atexit.register(lambda: (_SERVER_PROC and _SERVER_PROC.is_alive() and _SERVER_PROC.terminate())) def _make_input_rows() -> List[EvaluationRow]: @@ -47,12 +67,11 @@ def _make_input_rows() -> List[EvaluationRow]: @evaluation_test( input_rows=[_make_input_rows()], completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}], - rollout_processor=RemoteRolloutProcessor(), - rollout_processor_kwargs={ - "remote_base_url": "http://127.0.0.1:7077", - "num_turns": 2, - "timeout_seconds": 30, - }, + rollout_processor=RemoteRolloutProcessor( + remote_base_url="http://127.0.0.1:7077", + num_turns=2, + timeout_seconds=30, + ), mode="pointwise", ) async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow: