Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@
except ImportError:
LangSmithAdapter = None


# Optional dependency: expose WeaveAdapter when the weave extras are
# installed; otherwise bind the name to None so callers can feature-detect
# with `if WeaveAdapter is not None:` instead of catching ImportError.
try:
    from .adapters import WeaveAdapter
except ImportError:
    WeaveAdapter = None

warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")

__all__ = [
Expand Down
7 changes: 7 additions & 0 deletions eval_protocol/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,10 @@
__all__.extend(["LangSmithAdapter"])
except ImportError:
pass

# Register WeaveAdapter in the package's public API only when its optional
# dependencies are importable; silently skip otherwise (mirrors the
# LangSmithAdapter pattern above).
try:
    from .weave import WeaveAdapter

    __all__.extend(["WeaveAdapter"])
except ImportError:
    pass
130 changes: 130 additions & 0 deletions eval_protocol/adapters/weave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Weave (Weights & Biases) adapter for Eval Protocol.

This adapter fetches recent root traces from Weave Trace API and converts them
to `EvaluationRow` format for use in evaluation pipelines. It is intentionally
minimal and depends only on requests.
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional
import os
import requests

from eval_protocol.models import EvaluationRow, InputMetadata, Message, ExecutionMetadata
from .base import BaseAdapter


def _extract_messages_from_trace(trace: Dict[str, Any], include_tool_calls: bool = True) -> List[Message]:
    """Build the EP message list for a single Weave trace payload.

    Extraction priority for the output side: explicit ``output.messages``
    first, then the first entry of ``output.choices`` as a fallback. Any
    ``inputs.messages`` are prepended afterwards so the transcript reads
    input-first.

    Args:
        trace: Raw trace dict as returned by the Weave Trace API.
        include_tool_calls: When False, tool-call fields are dropped from
            output messages.

    Returns:
        Ordered list of ``Message`` objects (possibly empty).
    """
    output = trace.get("output") or {}
    inputs = trace.get("inputs") or {}

    collected: List[Message] = []

    explicit = output.get("messages")
    if isinstance(explicit, list):
        collected = [
            Message(
                role=entry.get("role"),
                content=entry.get("content"),
                tool_calls=entry.get("tool_calls") if include_tool_calls else None,
                tool_call_id=entry.get("tool_call_id"),
                name=entry.get("name"),
            )
            for entry in explicit
        ]

    # No explicit output messages: fall back to the first choice's message.
    if not collected:
        choices = output.get("choices")
        if isinstance(choices, list) and choices:
            final = (choices[0] or {}).get("message", {})
            if final:
                collected.append(Message(role=final.get("role"), content=final.get("content")))

    # Prepend the caller-supplied input messages, when present.
    supplied = inputs.get("messages")
    if isinstance(supplied, list):
        collected = [Message(role=entry.get("role"), content=entry.get("content")) for entry in supplied] + collected

    return collected


def _convert_trace_to_evaluation_row(
    trace: Dict[str, Any], include_tool_calls: bool = True
) -> Optional[EvaluationRow]:
    """Convert one Weave trace payload into an ``EvaluationRow``.

    Args:
        trace: Raw trace dict from the Weave Trace API.
        include_tool_calls: Forwarded to message extraction; also gates
            whether provider-exposed tools are captured.

    Returns:
        The converted row, or ``None`` when no messages could be extracted.
    """
    extracted = _extract_messages_from_trace(trace, include_tool_calls=include_tool_calls)
    if not extracted:
        return None

    # EP identifiers, when the provider payload carries them; output-side
    # metadata wins over input-side on key collisions.
    merged_meta = {
        **((trace.get("inputs") or {}).get("metadata") or {}),
        **((trace.get("output") or {}).get("metadata") or {}),
    }

    # Provider-native IDs live in session_data for UI joinability.
    input_meta = InputMetadata(
        row_id=merged_meta.get("row_id"),
        session_data={
            "weave_trace_id": trace.get("id"),
            "weave_project_id": trace.get("project_id"),
        },
    )

    # Forward only identifiers that are actually present so the
    # ExecutionMetadata default factories stay in effect for the rest.
    exec_fields: Dict[str, Any] = {
        key: merged_meta[key]
        for key in ("invocation_id", "experiment_id", "rollout_id", "run_id")
        if merged_meta.get(key) is not None
    }
    exec_meta = ExecutionMetadata(**exec_fields)

    # Capture tools if the provider exposes them (inputs side).
    inputs = trace.get("inputs") or {}
    tools = inputs.get("tools") if include_tool_calls and isinstance(inputs, dict) and "tools" in inputs else None

    return EvaluationRow(
        messages=extracted, tools=tools, input_metadata=input_meta, execution_metadata=exec_meta
    )


class WeaveAdapter(BaseAdapter):
    """Adapter that pulls root traces from the Weave Trace API and converts
    them to ``EvaluationRow`` format.

    Configuration resolution order: explicit constructor arguments, then the
    ``WEAVE_TRACE_BASE_URL`` / ``WANDB_API_KEY`` / ``WANDB_ENTITY`` +
    ``WANDB_PROJECT`` environment variables.
    """

    def __init__(
        self, base_url: Optional[str] = None, api_token: Optional[str] = None, project_id: Optional[str] = None
    ):
        """Initialize the adapter.

        Args:
            base_url: Weave Trace API root (default ``https://trace.wandb.ai``).
            api_token: W&B API key; falls back to ``WANDB_API_KEY``.
            project_id: Project in ``"<entity>/<project>"`` form; falls back
                to ``WANDB_ENTITY``/``WANDB_PROJECT`` env vars.

        Raises:
            ValueError: If the API key or a well-formed project id cannot be
                resolved.
        """
        self.base_url = base_url or os.getenv("WEAVE_TRACE_BASE_URL", "https://trace.wandb.ai")
        self.api_token = api_token or os.getenv("WANDB_API_KEY")
        # project_id has the form "<entity>/<project>". Build it from env vars
        # only when BOTH are set: the previous f-string fallback rendered the
        # bogus "None/None" when they were unset, which contains "/" and
        # slipped past the validation below, deferring failure to a confusing
        # API error at fetch time.
        if project_id is None:
            entity = os.getenv("WANDB_ENTITY")
            project = os.getenv("WANDB_PROJECT")
            project_id = f"{entity}/{project}" if entity and project else None
        self.project_id = project_id
        if not self.api_token or not self.project_id or "/" not in self.project_id:
            raise ValueError("Missing Weave credentials or project (WANDB_API_KEY and WANDB_ENTITY/WANDB_PROJECT)")

    def _fetch_traces(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Fetch up to ``limit`` most-recent root traces via POST
        ``/calls/stream_query``.

        Raises:
            requests.HTTPError: On non-2xx responses (via raise_for_status).
        """
        url = f"{self.base_url}/calls/stream_query"
        payload = {
            "project_id": self.project_id,
            "filter": {"trace_roots_only": True},  # root calls only, not child spans
            "limit": limit,
            "offset": 0,
            "sort_by": [{"field": "started_at", "direction": "desc"}],
            "include_feedback": False,
        }
        headers = {"Authorization": f"Bearer {self.api_token}", "Content-Type": "application/json"}
        resp = requests.post(url, json=payload, headers=headers, timeout=30)
        resp.raise_for_status()
        # NOTE(review): assumes the endpoint returns a single JSON object with
        # a "data" list — confirm against the Weave Trace API (stream endpoints
        # sometimes emit JSON Lines).
        body = resp.json() or {}
        return body.get("data", [])

    def get_evaluation_rows(self, *args, **kwargs) -> List[EvaluationRow]:
        """Fetch recent root traces and convert them to evaluation rows.

        Keyword Args:
            limit (int): Maximum number of traces to fetch (default 100).
            include_tool_calls (bool): Keep tool-call fields (default True).

        Returns:
            Rows for every trace that yielded at least one message.
        """
        limit = kwargs.get("limit", 100)
        include_tool_calls = kwargs.get("include_tool_calls", True)
        rows: List[EvaluationRow] = []
        for trace in self._fetch_traces(limit=limit):
            row = _convert_trace_to_evaluation_row(trace, include_tool_calls=include_tool_calls)
            if row is not None:
                rows.append(row)
        return rows
16 changes: 16 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,19 @@ The typical lifecycle of working with or developing an example involves these ke
2. Model your structure and documentation after `examples/math_example/`.
3. Ensure your example has its own clear `README.md` and necessary `conf/` files.
4. Test thoroughly.

## Tracing provider IO references

Provider-specific IO references (input logging + output pulling) live under:

- `examples/tracing/<provider>/`

Current providers:

- `examples/tracing/weave/`: Input/Output reference for Weave (W&B) tracing

Each provider folder includes:

- `produce_input_trace.py`: Minimal script to log a chat completion
- `pull_output_traces.py`: Script to fetch traces and convert to `EvaluationRow`
- `converter.py`: Provider-to-EP message+metadata mapping
10 changes: 10 additions & 0 deletions examples/adapters/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ Loads datasets from HuggingFace Hub and converts them to EvaluationRow format.
pip install 'eval-protocol[huggingface]'
```

## Tracing provider IO references

Provider-specific IO references (input logging + output pulling) have moved under:

- `examples/tracing/<provider>/`

For Weave, see `examples/tracing/weave/` which contains a focused `converter.py` illustrating how to map provider payloads to EP messages and metadata.

These examples are designed to be self-contained and usable as references for building or validating provider adapters.

## Running the Examples

### Basic Usage
Expand Down
Empty file added examples/tracing/__init__.py
Empty file.
6 changes: 6 additions & 0 deletions examples/tracing/weave/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Weave (Weights & Biases) tracing examples.

This package contains a focused `converter.py` that illustrates how to map
Weave provider payloads to Eval Protocol `EvaluationRow` objects. Use it as a
reference when building or validating provider adapters.
"""
78 changes: 78 additions & 0 deletions examples/tracing/weave/converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Any, Dict, List, Optional

from eval_protocol.models import EvaluationRow, InputMetadata, Message, ExecutionMetadata


def _extract_messages_from_trace(trace: Dict[str, Any], include_tool_calls: bool = True) -> List[Message]:
    """Turn a raw Weave trace dict into an ordered list of EP messages.

    Output side: explicit ``output.messages`` are preferred; otherwise the
    first entry of ``output.choices`` is used. Input-side ``inputs.messages``
    are then prepended so the conversation reads input-first.
    """
    out_payload = trace.get("output") or {}
    result: List[Message] = []

    # Explicit output messages win when the provider supplies them.
    explicit = out_payload.get("messages")
    if isinstance(explicit, list):
        for item in explicit:
            result.append(
                Message(
                    role=item.get("role"),
                    content=item.get("content"),
                    tool_calls=item.get("tool_calls") if include_tool_calls else None,
                    tool_call_id=item.get("tool_call_id"),
                    name=item.get("name"),
                )
            )

    # Otherwise fall back to the assistant bubble of the first choice.
    if not result:
        choice_list = out_payload.get("choices")
        if isinstance(choice_list, list) and choice_list:
            head = (choice_list[0] or {}).get("message", {})
            if head:
                result.append(Message(role=head.get("role"), content=head.get("content")))

    # Prepend caller-supplied input messages, when present.
    in_payload = trace.get("inputs") or {}
    supplied = in_payload.get("messages")
    if isinstance(supplied, list):
        result = [Message(role=item.get("role"), content=item.get("content")) for item in supplied] + result

    return result


def convert_trace_to_evaluation_row(trace: Dict[str, Any], include_tool_calls: bool = True) -> Optional[EvaluationRow]:
    """Map one Weave trace payload onto an ``EvaluationRow``.

    Returns ``None`` when no messages could be extracted from the trace.
    """
    msgs = _extract_messages_from_trace(trace, include_tool_calls=include_tool_calls)
    if not msgs:
        return None

    in_payload = trace.get("inputs") or {}
    out_payload = trace.get("output") or {}

    # Optional EP identifiers carried in the provider payload; output-side
    # metadata overrides input-side on key collisions.
    combined = {**(in_payload.get("metadata") or {}), **(out_payload.get("metadata") or {})}

    # Provider-native IDs are kept in session_data for UI joinability.
    in_meta = InputMetadata(
        row_id=combined.get("row_id"),
        session_data={
            "weave_trace_id": trace.get("id"),
            "weave_project_id": trace.get("project_id"),
        },
    )

    # Set only the identifiers actually present so ExecutionMetadata's
    # default factories handle everything else.
    present = {
        field: combined[field]
        for field in ("invocation_id", "experiment_id", "rollout_id", "run_id")
        if combined.get(field) is not None
    }
    exec_meta = ExecutionMetadata(**present)

    # Capture tools if the provider exposes them (inputs side).
    captured_tools = None
    if include_tool_calls and isinstance(in_payload, dict) and "tools" in in_payload:
        captured_tools = in_payload.get("tools")

    return EvaluationRow(
        messages=msgs,
        tools=captured_tools,
        input_metadata=in_meta,
        execution_metadata=exec_meta,
    )
61 changes: 61 additions & 0 deletions tests/adapters/test_weave_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import importlib.util
from pathlib import Path

import pytest


def _load_module_from_path(name: str, path: str):
    """Import the Python file at *path* as a module named *name* and return it.

    Lets the tests exercise example scripts without installing them as a
    package.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    assert spec and spec.loader, f"Failed to load module spec for {name} from {path}"
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)  # type: ignore[attr-defined]
    return loaded


@pytest.mark.skip(reason="Weave example only: converter IO smoke-test placeholder (no live fetch script).")
def test_weave_converter_basic_messages():
    """Offline smoke test: a minimal trace converts to a row with the input
    message prepended, the choice message appended, and provider IDs preserved
    in session_data."""
    root = Path(__file__).resolve().parents[2]
    converter_path = root / "examples" / "tracing" / "weave" / "converter.py"
    mod = _load_module_from_path("weave_converter", str(converter_path))
    convert = getattr(mod, "convert_trace_to_evaluation_row")

    trace = {
        "id": "tr_123",
        "project_id": "team/proj",
        "inputs": {"messages": [{"role": "user", "content": "Hi"}]},
        "output": {"choices": [{"message": {"role": "assistant", "content": "Hello"}}]},
    }

    row = convert(trace)
    # Guard first: convert() returns None when no messages are extracted, and
    # `row.messages` on None would raise AttributeError instead of failing
    # with a clear assertion.
    assert row is not None
    # One prepended input message + one assistant message from choices[0].
    assert len(row.messages) == 2
    assert row.messages[0].role == "user"
    assert row.messages[-1].content == "Hello"
    assert row.input_metadata.session_data.get("weave_trace_id") == "tr_123"
    assert row.input_metadata.session_data.get("weave_project_id") == "team/proj"


@pytest.mark.skip(reason="Credential-gated live fetch; enable locally with WANDB creds.")
def test_weave_fetch_and_convert_live():
    """Live round-trip: fetch one root trace from the Weave Trace API and
    convert it to an EvaluationRow.

    Skipped by default; remove the marker and set WANDB_API_KEY plus
    WANDB_ENTITY/WANDB_PROJECT (or the WEAVE_* fallbacks) to run locally.
    """
    # Require explicit env to avoid CI failures
    if not os.getenv("WANDB_API_KEY"):
        pytest.skip("WANDB_API_KEY not set")

    # WEAVE_TEAM_ID / WEAVE_PROJECT_ID are accepted as alternate spellings.
    team = os.getenv("WANDB_ENTITY") or os.getenv("WEAVE_TEAM_ID")
    project = os.getenv("WANDB_PROJECT") or os.getenv("WEAVE_PROJECT_ID")
    if not team or not project:
        pytest.skip("Weave project not configured")

    base_url = os.getenv("WEAVE_TRACE_BASE_URL", "https://trace.wandb.ai")
    # Load the example scripts straight from the repo tree (not installed as
    # a package), so this test exercises the shipped reference code.
    root = Path(__file__).resolve().parents[2]
    pull_path = root / "examples" / "tracing" / "weave" / "pull_output_traces.py"
    conv_path = root / "examples" / "tracing" / "weave" / "converter.py"

    pull_mod = _load_module_from_path("weave_pull", str(pull_path))
    conv_mod = _load_module_from_path("weave_converter", str(conv_path))

    fetch_weave_traces = getattr(pull_mod, "fetch_weave_traces")
    convert = getattr(conv_mod, "convert_trace_to_evaluation_row")

    traces = fetch_weave_traces(
        base_url=base_url, project_id=f"{team}/{project}", api_token=os.environ["WANDB_API_KEY"], limit=1
    )
    # convert() may return None for message-less traces; require at least one
    # usable row rather than exact equality with the fetch count.
    rows = [convert(tr) for tr in traces]
    assert any(r is not None for r in rows)
Loading