|
# Standard library
import json
import logging
import os

# Third-party
from openai import AsyncOpenAI
from pydantic import BaseModel

# Local
from eval_protocol.models import EvaluateResult, EvaluationRow, Message
from eval_protocol.pytest import AgentRolloutProcessor, evaluation_test

logger = logging.getLogger(__name__)
| 11 | + |
class ResponseFormat(BaseModel):
    """Structured-output schema for the LLM judge: a single numeric score."""

    # 1 when the model output contains the ground truth, 0 otherwise
    # (per the judge's system prompt below).
    score: float
3 | 14 |
|
4 | 15 |
|
@evaluation_test(
    input_dataset=["tests/pytest/datasets/gmail_inbox.jsonl"],
    rollout_processor=AgentRolloutProcessor(),
    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}],
    mode="pointwise",
    mcp_config_path="tests/pytest/mcp_configurations/klavis_strata_mcp.json",
)
async def test_pytest_klavis_mcp(row: EvaluationRow) -> EvaluationRow:
    """Score a Gmail-agent rollout with an LLM judge.

    Sends the agent's final message and the row's ground truth to a judge
    model, which returns JSON matching ``ResponseFormat`` — ``score`` is 1
    if the output contains the ground truth, else 0. The score and the raw
    judge output are stored on ``row.evaluation_result``.

    Requires the ``FIREWORKS_API_KEY`` environment variable.
    """
    ground_truth = row.ground_truth

    async with AsyncOpenAI(
        api_key=os.environ["FIREWORKS_API_KEY"], base_url="https://api.fireworks.ai/inference/v1"
    ) as client:
        response = await client.chat.completions.create(
            model="accounts/fireworks/models/kimi-k2-instruct-0905",
            messages=[
                {
                    "role": "system",
                    "content": "You are judging the output of the model versus the ground truth. Return score = 1 if the output contains the ground truth, 0 otherwise.",
                },
                {
                    "role": "user",
                    # BUG FIX: this string was missing the f-prefix, so the
                    # judge received the literal text
                    # "{row.messages[-1].content}" instead of the actual model
                    # output and ground truth.
                    "content": f"Final model output: {row.messages[-1].content}\nGround truth: {ground_truth}",
                },
            ],
            # Constrain the judge to emit JSON conforming to ResponseFormat.
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "ResponseFormat", "schema": ResponseFormat.model_json_schema()},
            },
        )
        response_text = response.choices[0].message.content
        logger.info("response_text: %s", response_text)
        # Tolerate empty/partial judge output: previously `or "{}"` followed by
        # ["score"] raised KeyError on its own fallback. Default to 0.0 instead.
        score = float(json.loads(response_text or "{}").get("score", 0.0))
        row.evaluation_result = EvaluateResult(
            score=score,
            reason=response_text,
        )
    return row
0 commit comments