From 493c64a0b77ba06945342cf3453fb0819b519a87 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Wed, 17 Sep 2025 15:55:25 -0700 Subject: [PATCH 1/2] Add tests for adapters --- tests/adapters/test_braintrust_adapter.py | 374 ++++++++++++++++++++++ tests/adapters/test_langfuse_adapter.py | 332 +++++++++++++++++++ tests/adapters/test_langsmith_adapter.py | 19 +- tests/test_quickstart_utils.py | 298 +++++++++++++++++ 4 files changed, 1013 insertions(+), 10 deletions(-) create mode 100644 tests/adapters/test_braintrust_adapter.py create mode 100644 tests/adapters/test_langfuse_adapter.py create mode 100644 tests/test_quickstart_utils.py diff --git a/tests/adapters/test_braintrust_adapter.py b/tests/adapters/test_braintrust_adapter.py new file mode 100644 index 00000000..fd1dcce4 --- /dev/null +++ b/tests/adapters/test_braintrust_adapter.py @@ -0,0 +1,374 @@ +import json +import os +from types import SimpleNamespace +from typing import Any, Dict, List +from unittest.mock import Mock + +import pytest +import requests + +from eval_protocol.adapters.braintrust import BraintrustAdapter +from eval_protocol.models import Message + + +class MockResponse: + """Mock response object for requests.post""" + + def __init__(self, json_data: Dict[str, Any], status_code: int = 200): + self.json_data = json_data + self.status_code = status_code + + def json(self) -> Dict[str, Any]: # noqa: F811 + return self.json_data + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise requests.HTTPError(f"HTTP {self.status_code}") + + +@pytest.fixture +def mock_requests_post(monkeypatch): + """Mock requests.post to return sample data""" + + def fake_post(url: str, headers=None, json_data=None): + # Return a simplified response for basic tests + return MockResponse( + { + "data": [ + { + "id": "trace1", + "input": [{"role": "user", "content": "Hello"}], + "output": [{"message": {"role": "assistant", "content": "Hi there!"}}], + } + ] + } + ) + + monkeypatch.setattr("requests.post", fake_post) + return fake_post + + +def test_basic_btql_query_returns_evaluation_rows(mock_requests_post): + """Test basic BTQL query execution and conversion to evaluation rows""" + adapter = BraintrustAdapter(api_key="test_key", project_id="test_project") + + btql_query = "select: * from: project_logs('test_project') traces limit: 1" + rows = adapter.get_evaluation_rows(btql_query) + + assert len(rows) == 1 + assert len(rows[0].messages) == 2 + assert rows[0].messages[0].role == "user" + assert rows[0].messages[0].content == "Hello" + assert rows[0].messages[1].role == "assistant" + assert rows[0].messages[1].content == "Hi there!" + + +def test_trace_with_tool_calls_preserved(monkeypatch): + """Test that tool calls are properly preserved in converted messages""" + + def mock_post(url: str, headers=None, json_data=None): + return MockResponse( + { + "data": [ + { + "id": "trace_with_tools", + "input": [{"role": "user", "content": "Get reservation details for 7KJ2PL"}], + "output": [ + { + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_reservation_details", + "arguments": '{"reservation_id": "7KJ2PL"}', + }, + } + ], + } + } + ], + } + ] + } + ) + + monkeypatch.setattr("requests.post", mock_post) + + adapter = BraintrustAdapter(api_key="test_key", project_id="test_project") + rows = adapter.get_evaluation_rows("test query") + + assert len(rows) == 1 + msgs = rows[0].messages + + # Find assistant message with tool calls + assistant_msgs = [m for m in msgs if m.role == "assistant" and m.tool_calls] + assert len(assistant_msgs) == 1 + + assert assistant_msgs[0].tool_calls is not None + tool_call = assistant_msgs[0].tool_calls[0] + assert tool_call.id == "call_123" + assert tool_call.function.name == "get_reservation_details" + assert '{"reservation_id": "7KJ2PL"}' in tool_call.function.arguments + + +def test_trace_with_tool_response_messages(monkeypatch): + """Test that tool response messages are properly handled""" + + def mock_post(url: str, headers=None, json_data=None): + return MockResponse( + { + "data": [ + { + "id": "trace_with_tool_response", + "input": [ + {"role": "user", "content": "Check reservation"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_456", + "type": "function", + "function": { + "name": "get_reservation_details", + "arguments": '{"reservation_id": "ABC123"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_456", + "content": '{"reservation_id": "ABC123", "status": "confirmed"}', + }, + ], + "output": [ + {"message": {"role": "assistant", "content": "Your reservation ABC123 is confirmed."}} + ], + } + ] + } + ) + + monkeypatch.setattr("requests.post", mock_post) + + adapter = BraintrustAdapter(api_key="test_key", project_id="test_project") + rows = adapter.get_evaluation_rows("test query") + + assert len(rows) == 1 + msgs = rows[0].messages + + # Should have user, assistant with tool_calls, tool response, and final assistant + roles = [m.role for m in msgs] + assert "user" in roles + assert "tool" in roles + assert roles.count("assistant") == 2 # One with tool_calls, one final response + + # Check tool message + tool_msgs = [m for m in msgs if m.role == "tool"] + assert len(tool_msgs) == 1 + assert tool_msgs[0].tool_call_id == "call_456" + assert tool_msgs[0].content is not None + assert "ABC123" in tool_msgs[0].content + + +def test_tools_extracted_from_metadata_variants(monkeypatch): + """Test that tools are extracted from different metadata locations""" + + def mock_post_with_tools_in_metadata(url: str, headers=None, json_data=None): + return MockResponse( + { + "data": [ + { + "id": "trace_with_metadata_tools", + "input": [{"role": "user", "content": "Test"}], + "output": [{"message": {"role": "assistant", "content": "Response"}}], + "metadata": { + "tools": [ + { + "type": "function", + "function": {"name": "get_weather", "description": "Get weather info"}, + } + ] + }, + } + ] + } + ) + + monkeypatch.setattr("requests.post", mock_post_with_tools_in_metadata) + + adapter = BraintrustAdapter(api_key="test_key", project_id="test_project") + rows = adapter.get_evaluation_rows("test query") + + assert len(rows) == 1 + assert rows[0].tools is not None + assert len(rows[0].tools) == 1 + assert rows[0].tools[0]["function"]["name"] == "get_weather" + + +def test_tools_extracted_from_hidden_params(monkeypatch): + """Test that tools are extracted from nested hidden_params location""" + + def mock_post_with_hidden_tools(url: str, headers=None, json_data=None): + return MockResponse( + { + "data": [ + { + "id": "trace_with_hidden_tools", + "input": [{"role": "user", "content": "Test"}], + "output": [{"message": {"role": "assistant", "content": "Response"}}], + "metadata": { + "hidden_params": { + "optional_params": { + "tools": [ + { + "type": "function", + "function": { + "name": "transfer_to_human_agents", + "description": "Transfer to human", + }, + } + ] + } + } + }, + } + ] + } + ) + + monkeypatch.setattr("requests.post", mock_post_with_hidden_tools) + + adapter = BraintrustAdapter(api_key="test_key", project_id="test_project") + rows = adapter.get_evaluation_rows("test query") + + assert len(rows) == 1 + assert rows[0].tools is not None + assert len(rows[0].tools) == 1 + assert rows[0].tools[0]["function"]["name"] == "transfer_to_human_agents" + + +def test_empty_btql_response_returns_empty_list(monkeypatch): + """Test that empty BTQL response returns empty list""" + + def mock_empty_post(url: str, headers=None, json_data=None): + return MockResponse({"data": []}) + + monkeypatch.setattr("requests.post", mock_empty_post) + + adapter = BraintrustAdapter(api_key="test_key", project_id="test_project") + rows = adapter.get_evaluation_rows("test query") + + assert len(rows) == 0 + + +def test_trace_without_meaningful_conversation_skipped(monkeypatch): + """Test that traces without input or output are skipped""" + + def mock_post_incomplete_trace(url: str, headers=None, json_data=None): + return MockResponse( + { + "data": [ + {"id": "incomplete_trace", "input": None, "output": []}, + { + "id": "valid_trace", + "input": [{"role": "user", "content": "Hello"}], + "output": [{"message": {"role": "assistant", "content": "Hi"}}], + }, + ] + } + ) + + monkeypatch.setattr("requests.post", mock_post_incomplete_trace) + + adapter = BraintrustAdapter(api_key="test_key", project_id="test_project") + rows = adapter.get_evaluation_rows("test query") + + # Should only get the valid trace + assert len(rows) == 1 + assert rows[0].input_metadata is not None + assert rows[0].input_metadata.session_data is not None + assert rows[0].input_metadata.session_data["braintrust_trace_id"] == "valid_trace" + + +def test_custom_converter_used_when_provided(monkeypatch): + """Test that custom converter is used when provided""" + + def mock_post(url: str, headers=None, json_data=None): + return MockResponse( + { + "data": [ + { + "id": "custom_trace", + "input": [{"role": "user", "content": "Test"}], + "output": [{"message": {"role": "assistant", "content": "Response"}}], + } + ] + } + ) + + monkeypatch.setattr("requests.post", mock_post) + + def custom_converter(trace: Dict[str, Any], include_tool_calls: bool): + # Custom converter that adds a special message + from eval_protocol.models import EvaluationRow, InputMetadata + + return EvaluationRow( + messages=[Message(role="system", content="Custom converted message")], + input_metadata=InputMetadata(session_data={"custom": True}), + ) + + adapter = BraintrustAdapter(api_key="test_key", project_id="test_project") + rows = adapter.get_evaluation_rows("test query", converter=custom_converter) + + assert len(rows) == 1 + assert rows[0].messages[0].role == "system" + assert rows[0].messages[0].content == "Custom converted message" + assert rows[0].input_metadata is not None + assert rows[0].input_metadata.session_data is not None + assert rows[0].input_metadata.session_data["custom"] is True + + +def test_api_authentication_error_handling(monkeypatch): + """Test that API authentication errors are handled properly""" + + def mock_auth_error(url: str, headers=None, json_data=None): + return MockResponse({}, status_code=401) + + monkeypatch.setattr("requests.post", mock_auth_error) + + adapter = BraintrustAdapter(api_key="invalid_key", project_id="test_project") + + with pytest.raises(requests.HTTPError): + adapter.get_evaluation_rows("test query") + + +def test_session_data_includes_trace_id(mock_requests_post): + """Test that session_data includes the Braintrust trace ID""" + adapter = BraintrustAdapter(api_key="test_key", project_id="test_project") + rows = adapter.get_evaluation_rows("test query") + + assert len(rows) == 1 + assert rows[0].input_metadata is not None + assert rows[0].input_metadata.session_data is not None + assert rows[0].input_metadata.session_data["braintrust_trace_id"] == "trace1" + + +def test_missing_required_env_vars(monkeypatch): + """Test that missing required environment variables raise errors""" + # Mock environment variables to be None + monkeypatch.setenv("BRAINTRUST_API_KEY", "") + monkeypatch.setenv("BRAINTRUST_PROJECT_ID", "") + + # Test missing API key + with pytest.raises(ValueError, match="BRAINTRUST_API_KEY"): + BraintrustAdapter(api_key=None, project_id="test_project") + + # Test missing project ID + with pytest.raises(ValueError, match="BRAINTRUST_PROJECT_ID"): + BraintrustAdapter(api_key="test_key", project_id=None) diff --git a/tests/adapters/test_langfuse_adapter.py b/tests/adapters/test_langfuse_adapter.py new file mode 100644 index 00000000..3a4c2699 --- /dev/null +++ b/tests/adapters/test_langfuse_adapter.py @@ -0,0 +1,332 @@ +"""Tests for Langfuse adapter.""" + +from datetime import datetime, timedelta +from types import SimpleNamespace +from typing import Any, Dict, List, Optional +from unittest.mock import Mock + +import pytest + +from eval_protocol.adapters.langfuse import ( + LangfuseAdapter, + convert_trace_to_evaluation_row, + extract_messages_from_trace, +) +from eval_protocol.models import EvaluationRow, InputMetadata, Message + + +class FakeLangfuseClient: + """Mock Langfuse client for testing""" + + def __init__(self, traces_list_response, trace_details): + self.traces_list_response = traces_list_response + self.trace_details = trace_details + + @property + def api(self): + return FakeLangfuseAPI(self.traces_list_response, self.trace_details) + + def create_score(self, trace_id: str, name: str, value: float): + """Mock score creation""" + pass + + +class FakeLangfuseAPI: + """Mock Langfuse API for testing""" + + def __init__(self, traces_list_response, trace_details): + self.traces_list_response = traces_list_response + self.trace_details = trace_details + + @property + def trace(self): + return FakeLangfuseTraceAPI(self.traces_list_response, self.trace_details) + + +class FakeLangfuseTraceAPI: + """Mock Langfuse trace API for testing""" + + def __init__(self, traces_list_response, trace_details): + self.traces_list_response = traces_list_response + self.trace_details = trace_details + + def list(self, **kwargs): + """Mock trace list method""" + return self.traces_list_response + + def get(self, trace_id: str): + """Mock trace get method""" + return self.trace_details.get(trace_id, self.trace_details.get("default")) + + +def _create_mock_trace( + trace_id: str, input_data: Any = None, output_data: Any = None, observations: Optional[List] = None +): + """Helper to create mock trace objects""" + return SimpleNamespace(id=trace_id, input=input_data, output=output_data, observations=observations or []) + + +def _create_mock_traces_response(traces: List[Dict[str, Any]]): + """Helper to create mock traces list response""" + trace_objects = [] + for trace_data in traces: + trace_objects.append(SimpleNamespace(**trace_data)) + + return SimpleNamespace(data=trace_objects, meta=SimpleNamespace(page=1, total_pages=1)) + + +@pytest.fixture +def mock_langfuse_client(monkeypatch): + """Mock the Langfuse client""" + + def fake_get_client(): + traces_response = _create_mock_traces_response([{"id": "trace1", "name": "test_trace"}]) + trace_details = { + "default": _create_mock_trace( + "trace1", + input_data={"messages": [{"role": "user", "content": "Hello"}]}, + output_data={"messages": [{"role": "assistant", "content": "Hi there!"}]}, + ) + } + return FakeLangfuseClient(traces_response, trace_details) + + monkeypatch.setattr("eval_protocol.adapters.langfuse.get_client", fake_get_client) + return fake_get_client + + +def test_basic_trace_conversion(): + """Test basic trace to evaluation row conversion""" + trace = _create_mock_trace( + "trace123", + input_data={"messages": [{"role": "user", "content": "What's the weather?"}]}, + output_data={"messages": [{"role": "assistant", "content": "It's sunny!"}]}, + ) + + result = convert_trace_to_evaluation_row(trace) # pyright: ignore[reportArgumentType] + + assert result is not None + assert len(result.messages) == 2 + assert result.messages[0].role == "user" + assert result.messages[0].content == "What's the weather?" + assert result.messages[1].role == "assistant" + assert result.messages[1].content == "It's sunny!" + assert result.input_metadata is not None + assert result.input_metadata.session_data is not None + assert result.input_metadata.session_data["langfuse_trace_id"] == "trace123" + + +def test_trace_with_tool_calls(): + """Test trace conversion with tool calls""" + trace = _create_mock_trace( + "trace_tools", + input_data={ + "messages": [{"role": "user", "content": "Get weather for NYC"}], + "tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather info"}}], + }, + output_data={ + "messages": [ + { + "role": "assistant", + "content": "I'll check the weather for you.", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "NYC"}'}, + } + ], + } + ] + }, + ) + + result = convert_trace_to_evaluation_row(trace, include_tool_calls=True) # pyright: ignore[reportArgumentType] + + assert result is not None + assert result.tools is not None + assert len(result.tools) == 1 + assert result.tools[0]["function"]["name"] == "get_weather" + + # Check that tool calls are preserved in messages + assistant_msgs = [m for m in result.messages if m.role == "assistant" and m.tool_calls] + assert len(assistant_msgs) == 1 + # Tool calls are converted to OpenAI format objects, not dicts + assert assistant_msgs[0].tool_calls is not None + tool_call = assistant_msgs[0].tool_calls[0] + assert hasattr(tool_call, "function") + assert tool_call.function.name == "get_weather" + + +def test_trace_conversion_with_span_name(): + """Test trace conversion with specific span name""" + # Mock observations with spans and generations + observations = [ + SimpleNamespace(id="span1", name="judge", type="SPAN"), + SimpleNamespace( + id="gen1", + name="generation", + type="GENERATION", + parent_observation_id="span1", + input={"messages": [{"role": "user", "content": "Judge this"}]}, + output={"messages": [{"role": "assistant", "content": "Good response"}]}, + start_time=datetime.now(), + ), + ] + + trace = _create_mock_trace("trace_span", observations=observations) + result = convert_trace_to_evaluation_row(trace, span_name="judge") # pyright: ignore[reportArgumentType] + + assert result is not None + assert len(result.messages) == 2 + assert result.messages[0].content == "Judge this" + assert result.messages[1].content == "Good response" + + +def test_empty_trace_returns_none(): + """Test that empty traces return None""" + trace = _create_mock_trace("empty_trace", input_data=None, output_data=None) + + result = convert_trace_to_evaluation_row(trace) # pyright: ignore[reportArgumentType] + + assert result is None + + +def test_malformed_trace_returns_none(): + """Test that malformed traces are handled gracefully""" + # Trace with missing required attributes + trace = SimpleNamespace(id="malformed") # Missing input/output + + result = convert_trace_to_evaluation_row(trace) # pyright: ignore[reportArgumentType] + + assert result is None + + +def test_langfuse_adapter_initialization(mock_langfuse_client): + """Test LangfuseAdapter initialization""" + adapter = LangfuseAdapter() + assert adapter.client is not None + + +def test_langfuse_adapter_unavailable(): + """Test that ImportError is raised when Langfuse is not available""" + import eval_protocol.adapters.langfuse as langfuse_module + + # Temporarily set LANGFUSE_AVAILABLE to False + original_available = langfuse_module.LANGFUSE_AVAILABLE + langfuse_module.LANGFUSE_AVAILABLE = False + + try: + with pytest.raises(ImportError, match="Langfuse not installed"): + LangfuseAdapter() + finally: + # Restore original value + langfuse_module.LANGFUSE_AVAILABLE = original_available + + +def test_get_evaluation_rows_basic(mock_langfuse_client): + """Test basic get_evaluation_rows functionality""" + adapter = LangfuseAdapter() + + rows = adapter.get_evaluation_rows(limit=1) + + assert len(rows) == 1 + assert rows[0].messages[0].role == "user" + assert rows[0].messages[0].content == "Hello" + assert rows[0].input_metadata is not None + assert rows[0].input_metadata.session_data is not None + assert rows[0].input_metadata.session_data["langfuse_trace_id"] == "trace1" + + +def test_get_evaluation_rows_by_ids(mock_langfuse_client): + """Test get_evaluation_rows_by_ids functionality""" + adapter = LangfuseAdapter() + + rows = adapter.get_evaluation_rows_by_ids(["trace1"]) + + assert len(rows) == 1 + assert rows[0].input_metadata is not None + assert rows[0].input_metadata.session_data is not None + assert rows[0].input_metadata.session_data["langfuse_trace_id"] == "trace1" + + +def test_get_evaluation_rows_by_ids_with_custom_converter(mock_langfuse_client): + """Test get_evaluation_rows_by_ids with custom converter""" + + def custom_converter(trace, include_tool_calls: bool, span_name: Optional[str]): + return EvaluationRow( + messages=[Message(role="system", content="Custom converted message")], + input_metadata=InputMetadata(session_data={"custom": True, "langfuse_trace_id": trace.id}), + ) + + adapter = LangfuseAdapter() + rows = adapter.get_evaluation_rows_by_ids(["trace1"], converter=custom_converter) + + assert len(rows) == 1 + assert rows[0].messages[0].role == "system" + assert rows[0].messages[0].content == "Custom converted message" + assert rows[0].input_metadata is not None + assert rows[0].input_metadata.session_data is not None + assert rows[0].input_metadata.session_data["custom"] is True + + +def test_sampling_functionality(monkeypatch): + """Test that sampling works correctly""" + + def fake_get_client(): + # Create multiple traces + traces_response = _create_mock_traces_response([{"id": f"trace{i}", "name": f"trace_{i}"} for i in range(10)]) + trace_details = { + f"trace{i}": _create_mock_trace( + f"trace{i}", + input_data={"messages": [{"role": "user", "content": f"Message {i}"}]}, + output_data={"messages": [{"role": "assistant", "content": f"Response {i}"}]}, + ) + for i in range(10) + } + trace_details["default"] = trace_details["trace0"] + + return FakeLangfuseClient(traces_response, trace_details) + + monkeypatch.setattr("eval_protocol.adapters.langfuse.get_client", fake_get_client) + + adapter = LangfuseAdapter() + rows = adapter.get_evaluation_rows(limit=10, sample_size=3) + + # Should get exactly 3 rows due to sampling + assert len(rows) == 3 + + +def test_extract_messages_from_various_formats(): + """Test message extraction from different input formats""" + # Test dict format with messages + trace1 = _create_mock_trace( + "trace1", + input_data={"messages": [{"role": "user", "content": "Hello"}]}, + output_data={"messages": [{"role": "assistant", "content": "Hi"}]}, + ) + messages1 = extract_messages_from_trace(trace1) + assert len(messages1) == 2 + assert messages1[0].role == "user" + assert messages1[1].role == "assistant" + + # Test simple prompt format + trace2 = _create_mock_trace( + "trace2", input_data={"prompt": "What is AI?"}, output_data={"content": "AI is artificial intelligence"} + ) + messages2 = extract_messages_from_trace(trace2) + assert len(messages2) == 2 + assert messages2[0].role == "user" + assert messages2[0].content == "What is AI?" + assert messages2[1].role == "assistant" + assert messages2[1].content == "AI is artificial intelligence" + + # Test list format + trace3 = _create_mock_trace( + "trace3", + input_data=[{"role": "user", "content": "List format"}], + output_data=[{"role": "assistant", "content": "Response"}], + ) + messages3 = extract_messages_from_trace(trace3) + assert len(messages3) == 2 + assert messages3[0].content == "List format" + assert messages3[1].content == "Response" diff --git a/tests/adapters/test_langsmith_adapter.py b/tests/adapters/test_langsmith_adapter.py index 2f32282a..7956b980 100644 --- a/tests/adapters/test_langsmith_adapter.py +++ b/tests/adapters/test_langsmith_adapter.py @@ -37,7 +37,7 @@ def test_outputs_messages_preferred_and_dedup_user(): }, ) ] - adapter = LangSmithAdapter(client=FakeClient(runs)) + adapter = LangSmithAdapter(client=FakeClient(runs)) # pyright: ignore[reportArgumentType] rows = adapter.get_evaluation_rows(project_name="p", limit=10) assert len(rows) == 1 msgs = rows[0].messages @@ -53,7 +53,7 @@ def test_inputs_variants_prompt_user_input_input(): SimpleNamespace(id="p3", inputs={"input": "C"}, outputs={"answer": "OC"}), SimpleNamespace(id="p4", inputs="D", outputs="OD"), ] - adapter = LangSmithAdapter(client=FakeClient(runs)) + adapter = LangSmithAdapter(client=FakeClient(runs)) # pyright: ignore[reportArgumentType] rows = adapter.get_evaluation_rows(project_name="p", limit=10) texts = [[(m.role, m.content) for m in r.messages] for r in rows] assert ("user", "A") in texts[0] @@ -71,7 +71,7 @@ def test_outputs_variants_and_list_payloads(): SimpleNamespace(id="o1", inputs=[], outputs={"output": "X"}), SimpleNamespace(id="o2", inputs=[_msg("user", "U")], outputs=[_msg("assistant", "V")]), ] - adapter = LangSmithAdapter(client=FakeClient(runs)) + adapter = LangSmithAdapter(client=FakeClient(runs)) # pyright: ignore[reportArgumentType] rows = adapter.get_evaluation_rows(project_name="p", limit=10) msgs1 = rows[0].messages assert any(m.role == "assistant" and m.content == "X" for m in msgs1) @@ -108,14 +108,15 @@ def test_tool_calls_and_tool_role_preserved(): }, ) ] - adapter = LangSmithAdapter(client=FakeClient(runs)) + adapter = LangSmithAdapter(client=FakeClient(runs)) # pyright: ignore[reportArgumentType] rows = adapter.get_evaluation_rows(project_name="p", limit=10) msgs = rows[0].messages # Ensure tool role present - assert any(m.role == "tool" and (m.content or "").strip() == "5" for m in msgs) + assert any(m.role == "tool" and str(m.content or "").strip() == "5" for m in msgs) # Ensure assistant with tool_calls preserved assistants = [m for m in msgs if m.role == "assistant" and m.tool_calls] assert len(assistants) >= 1 + assert assistants[0].tool_calls is not None tc = assistants[0].tool_calls[0] # tool_calls may be provider-native objects; normalize via getattr first fname = None @@ -141,7 +142,7 @@ def test_system_prompt_first_and_multiple_user_allowed(): outputs={"content": "hello there"}, ) ] - adapter = LangSmithAdapter(client=FakeClient(runs)) + adapter = LangSmithAdapter(client=FakeClient(runs)) # pyright: ignore[reportArgumentType] rows = adapter.get_evaluation_rows(project_name="p", limit=10) msgs = rows[0].messages roles = [m.role for m in msgs] @@ -170,14 +171,12 @@ def test_parallel_tool_calls_normalized(): outputs={"messages": [assistant_with_tools]}, ), ] - adapter = LangSmithAdapter(client=FakeClient(runs)) + adapter = LangSmithAdapter(client=FakeClient(runs)) # pyright: ignore[reportArgumentType] rows = adapter.get_evaluation_rows(project_name="p", limit=10) msgs = rows[0].messages assistants = [m for m in msgs if m.role == "assistant" and m.tool_calls] assert len(assistants) == 1 tcs = assistants[0].tool_calls assert isinstance(tcs, list) and len(tcs) == 2 - names = [ - (getattr(tc, "function").name if hasattr(tc, "function") else tc.get("function", {}).get("name")) for tc in tcs - ] + names = [getattr(tc.function, "name", None) if hasattr(tc, "function") else None for tc in tcs] assert names == ["calculator_add", "calculator_add"] diff --git a/tests/test_quickstart_utils.py b/tests/test_quickstart_utils.py new file mode 100644 index 00000000..860d1d35 --- /dev/null +++ b/tests/test_quickstart_utils.py @@ -0,0 +1,298 @@ +"""Tests for quickstart utility functions.""" + +import pytest + +from eval_protocol.models import EvaluationRow, InputMetadata, Message +from eval_protocol.quickstart.utils import split_multi_turn_rows, serialize_message + + +class TestSerializeMessage: + """Tests for serialize_message function.""" + + def test_simple_message(self): + """Test serialization of a simple message.""" + message = Message(role="user", content="Hello, how are you?") + result = serialize_message(message) + assert result == "user: Hello, how are you?" + + def test_assistant_message(self): + """Test serialization of an assistant message.""" + message = Message(role="assistant", content="I'm doing well, thank you!") + result = serialize_message(message) + assert result == "assistant: I'm doing well, thank you!" + + def test_message_with_tool_calls(self): + """Test serialization of a message with tool calls.""" + tool_call = { + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "New York"}'}, + } + message = Message( + role="assistant", + content="I'll check the weather for you.", + tool_calls=[tool_call], # pyright: ignore[reportArgumentType] + ) + result = serialize_message(message) + expected = 'assistant: I\'ll check the weather for you.\n[Tool Call: get_weather({"location": "New York"})]' + assert result == expected + + def test_message_with_multiple_tool_calls(self): + """Test serialization of a message with multiple tool calls.""" + tool_call1 = { + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "NYC"}'}, + } + tool_call2 = { + "id": "call_456", + "type": "function", + "function": {"name": "get_time", "arguments": '{"timezone": "EST"}'}, + } + message = Message( + role="assistant", + content="Let me get both for you.", + tool_calls=[tool_call1, tool_call2], # pyright: ignore[reportArgumentType] + ) + result = serialize_message(message) + expected = ( + "assistant: Let me get both for you.\n" + '[Tool Call: get_weather({"location": "NYC"})]\n' + '[Tool Call: get_time({"timezone": "EST"})]' + ) + assert result == expected + + def test_empty_content_message(self): + """Test serialization of a message with empty content.""" + message = Message(role="assistant", content="") + result = serialize_message(message) + assert result == "assistant: " + + def test_none_content_message(self): + """Test serialization of a message with None content.""" + message = Message(role="assistant", content=None) + result = serialize_message(message) + assert result == "assistant: None" + + +class TestSplitMultiTurnRows: + """Tests for split_multi_turn_rows function.""" + + def test_single_turn_conversation(self): + """Test that single-turn conversations are handled correctly.""" + messages = [ + Message(role="user", content="What's the weather like?"), + Message(role="assistant", content="It's sunny today!"), + ] + row = EvaluationRow(messages=messages) + + result = split_multi_turn_rows([row]) + + assert len(result) == 1 + assert len(result[0].messages) == 1 # Only user message before assistant + assert result[0].messages[0].role == "user" + assert result[0].messages[0].content == "What's the weather like?" + assert result[0].ground_truth == "assistant: It's sunny today!" + + def test_multi_turn_conversation(self): + """Test that multi-turn conversations are split correctly.""" + messages = [ + Message(role="user", content="Hello"), + Message(role="assistant", content="Hi there!"), + Message(role="user", content="How are you?"), + Message(role="assistant", content="I'm doing well, thanks!"), + ] + row = EvaluationRow(messages=messages) + + result = split_multi_turn_rows([row]) + + assert len(result) == 2 + + # First split: user -> assistant + assert len(result[0].messages) == 1 + assert result[0].messages[0].content == "Hello" + assert result[0].ground_truth == "assistant: Hi there!" + + # Second split: user -> assistant -> user -> assistant + assert len(result[1].messages) == 3 + assert result[1].messages[0].content == "Hello" + assert result[1].messages[1].content == "Hi there!" + assert result[1].messages[2].content == "How are you?" + assert result[1].ground_truth == "assistant: I'm doing well, thanks!" + + def test_conversation_with_system_message(self): + """Test that system messages are preserved in splits.""" + messages = [ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="Hello"), + Message(role="assistant", content="Hi there!"), + Message(role="user", content="How are you?"), + Message(role="assistant", content="I'm doing well!"), + ] + row = EvaluationRow(messages=messages) + + result = split_multi_turn_rows([row]) + + assert len(result) == 2 + + # First split should include system message + assert len(result[0].messages) == 2 + assert result[0].messages[0].role == "system" + assert result[0].messages[1].role == "user" + + # Second split should include system message and previous conversation + assert len(result[1].messages) == 4 + assert result[1].messages[0].role == "system" + + def test_conversation_with_tool_calls(self): + """Test that tool calls are preserved in ground truth.""" + tool_call = { + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "NYC"}'}, + } + messages = [ + Message(role="user", content="What's the weather in NYC?"), + Message( + role="assistant", + content="I'll check that for you.", + tool_calls=[tool_call], # pyright: ignore[reportArgumentType] + ), + ] + row = EvaluationRow(messages=messages) + + result = split_multi_turn_rows([row]) + + assert len(result) == 1 + expected_ground_truth = 'assistant: I\'ll check that for you.\n[Tool Call: get_weather({"location": "NYC"})]' + assert result[0].ground_truth == expected_ground_truth + + def test_multiple_rows_processing(self): + """Test that multiple input rows are processed correctly.""" + row1 = EvaluationRow( + messages=[Message(role="user", content="Hello"), Message(role="assistant", content="Hi!")] + ) + row2 = EvaluationRow( + messages=[Message(role="user", content="Goodbye"), Message(role="assistant", content="Bye!")] + ) + + result = split_multi_turn_rows([row1, row2]) + + assert len(result) == 2 + assert result[0].messages[0].content == "Hello" + assert result[0].ground_truth == "assistant: Hi!" + assert result[1].messages[0].content == "Goodbye" + assert result[1].ground_truth == "assistant: Bye!" + + def test_no_assistant_messages(self): + """Test that rows with no assistant messages return empty list.""" + messages = [Message(role="user", content="Hello"), Message(role="user", content="Anyone there?")] + row = EvaluationRow(messages=messages) + + result = split_multi_turn_rows([row]) + + assert len(result) == 0 + + def test_only_assistant_messages(self): + """Test handling of rows with only assistant messages.""" + messages = [Message(role="assistant", content="Hello!"), Message(role="assistant", content="How can I help?")] + row = EvaluationRow(messages=messages) + + result = split_multi_turn_rows([row]) + + assert len(result) == 2 + # First assistant message (no context) + assert len(result[0].messages) == 0 + assert result[0].ground_truth == "assistant: Hello!" + # Second assistant message (with first assistant as context) + assert len(result[1].messages) == 1 + assert result[1].messages[0].content == "Hello!" + assert result[1].ground_truth == "assistant: How can I help?" + + def test_duplicate_trace_filtering(self): + """Test that duplicate traces are filtered out.""" + # Create two rows with the same conversation leading to different assistant responses + messages1 = [ + Message(role="user", content="Hello"), + Message(role="assistant", content="Hi there!"), + Message(role="user", content="How are you?"), + Message(role="assistant", content="I'm good!"), + ] + messages2 = [ + Message(role="user", content="Hello"), + Message(role="assistant", content="Hi there!"), + Message(role="user", content="How are you?"), + Message(role="assistant", content="I'm great!"), # Different response + ] + + row1 = EvaluationRow(messages=messages1) + row2 = EvaluationRow(messages=messages2) + + result = split_multi_turn_rows([row1, row2]) + + # Should only get 2 unique splits (not 4), because the context leading + # to the second assistant message is the same in both rows + assert len(result) == 2 # First "Hello" -> "Hi there!", then one unique context for second assistant + + # Verify the unique traces + contexts = ["\n".join(serialize_message(m) for m in r.messages) for r in result] + assert len(set(contexts)) == len(contexts) # All contexts should be unique + + def test_tools_and_metadata_preservation(self): + """Test that tools and input_metadata are preserved in split rows.""" + tools = [{"type": "function", "function": {"name": "test_tool"}}] + input_metadata = InputMetadata( + row_id="test_row", completion_params={"model": "gpt-4"}, session_data={"test": "data"} + ) + + messages = [Message(role="user", content="Hello"), Message(role="assistant", content="Hi!")] + row = EvaluationRow(messages=messages, tools=tools, input_metadata=input_metadata) + + result = split_multi_turn_rows([row]) + + assert len(result) == 1 + assert result[0].tools == tools + assert result[0].input_metadata == input_metadata + + def test_empty_input_list(self): + """Test that empty input list returns empty result.""" + result = split_multi_turn_rows([]) + assert len(result) == 0 + + def test_complex_multi_turn_with_tool_responses(self): + """Test complex conversation with tool calls and responses.""" + tool_call = { + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "NYC"}'}, + } + + messages = [ + Message(role="user", content="What's the weather in NYC?"), + Message( + role="assistant", + content="I'll check that for you.", + tool_calls=[tool_call], # pyright: ignore[reportArgumentType] + ), + Message(role="tool", tool_call_id="call_123", content="Sunny, 75°F"), + Message(role="assistant", content="It's sunny and 75°F in NYC!"), + Message(role="user", content="Thanks!"), + Message(role="assistant", content="You're welcome!"), + ] + row = EvaluationRow(messages=messages) + + result = split_multi_turn_rows([row]) + + assert len(result) == 3 # Three assistant messages + + # First assistant message with tool call + assert len(result[0].messages) == 1 # Just user message + assert "Tool Call: get_weather" in str(result[0].ground_truth or "") + + # Second assistant message after tool response + assert len(result[1].messages) == 3 # user, assistant with tool call, tool response + assert result[1].ground_truth == "assistant: It's sunny and 75°F in NYC!" + + # Third assistant message + assert len(result[2].messages) == 5 # All previous messages + "Thanks!" + assert result[2].ground_truth == "assistant: You're welcome!" From 4e8bfddf8231e43eec732118e33a10703350c018 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Wed, 17 Sep 2025 16:05:27 -0700 Subject: [PATCH 2/2] fix test --- tests/adapters/test_braintrust_adapter.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/adapters/test_braintrust_adapter.py b/tests/adapters/test_braintrust_adapter.py index fd1dcce4..97bf1cfa 100644 --- a/tests/adapters/test_braintrust_adapter.py +++ b/tests/adapters/test_braintrust_adapter.py @@ -1,4 +1,3 @@ -import json import os from types import SimpleNamespace from typing import Any, Dict, List @@ -18,7 +17,7 @@ def __init__(self, json_data: Dict[str, Any], status_code: int = 200): self.json_data = json_data self.status_code = status_code - def json(self) -> Dict[str, Any]: # noqa: F811 + def json(self) -> Dict[str, Any]: return self.json_data def raise_for_status(self) -> None: @@ -30,7 +29,7 @@ def raise_for_status(self) -> None: def mock_requests_post(monkeypatch): """Mock requests.post to return sample data""" - def fake_post(url: str, headers=None, json_data=None): + def fake_post(url: str, headers=None, json=None): # Return a simplified response for basic tests return MockResponse( { @@ -66,7 +65,7 @@ def test_basic_btql_query_returns_evaluation_rows(mock_requests_post): def test_trace_with_tool_calls_preserved(monkeypatch): """Test that tool calls are properly preserved in converted messages""" - def mock_post(url: str, headers=None, json_data=None): + def mock_post(url: str, headers=None, json=None): return MockResponse( { "data": [ @@ -118,7 +117,7 @@ def mock_post(url: str, headers=None, json_data=None): def test_trace_with_tool_response_messages(monkeypatch): """Test that tool response messages are properly handled""" - def mock_post(url: str, headers=None, json_data=None): + def mock_post(url: str, headers=None, json=None): return MockResponse( { "data": [ @@ -179,7 +178,7 @@ def mock_post(url: str, headers=None, json_data=None): def test_tools_extracted_from_metadata_variants(monkeypatch): """Test that tools are extracted from different metadata locations""" - def mock_post_with_tools_in_metadata(url: str, headers=None, json_data=None): + def mock_post_with_tools_in_metadata(url: str, headers=None, json=None): return MockResponse( { "data": [ @@ -214,7 +213,7 @@ def mock_post_with_tools_in_metadata(url: str, headers=None, json_data=None): def test_tools_extracted_from_hidden_params(monkeypatch): """Test that tools are extracted from nested hidden_params location""" - def mock_post_with_hidden_tools(url: str, headers=None, json_data=None): + def mock_post_with_hidden_tools(url: str, headers=None, json=None): return MockResponse( { "data": [ @@ -256,7 +255,7 @@ def mock_post_with_hidden_tools(url: str, headers=None, json_data=None): def test_empty_btql_response_returns_empty_list(monkeypatch): """Test that empty BTQL response returns empty list""" - def mock_empty_post(url: str, headers=None, json_data=None): + def mock_empty_post(url: str, headers=None, json=None): return MockResponse({"data": []}) monkeypatch.setattr("requests.post", mock_empty_post) @@ -270,7 +269,7 @@ def mock_empty_post(url: str, headers=None, json_data=None): def test_trace_without_meaningful_conversation_skipped(monkeypatch): """Test that traces without input or output are skipped""" - def mock_post_incomplete_trace(url: str, headers=None, json_data=None): + def mock_post_incomplete_trace(url: str, headers=None, json=None): return MockResponse( { "data": [ @@ -299,7 +298,7 @@ def mock_post_incomplete_trace(url: str, headers=None, json_data=None): def test_custom_converter_used_when_provided(monkeypatch): """Test that custom converter is used when provided""" - def mock_post(url: str, headers=None, json_data=None): + def mock_post(url: str, headers=None, json=None): return MockResponse( { "data": [ @@ -337,7 +336,7 @@ def custom_converter(trace: Dict[str, Any], include_tool_calls: bool): def test_api_authentication_error_handling(monkeypatch): """Test that API authentication errors are handled properly""" - def mock_auth_error(url: str, headers=None, json_data=None): + def mock_auth_error(url: str, headers=None, json=None): return MockResponse({}, status_code=401) monkeypatch.setattr("requests.post", mock_auth_error)