|
1 | | -from datetime import datetime |
2 | | -from typing import List |
3 | | - |
| 1 | +import pytest |
| 2 | +from eval_protocol.dataset_logger.dataset_logger import DatasetLogger |
4 | 3 | from eval_protocol.models import EvaluateResult, EvaluationRow, Message |
5 | 4 | from eval_protocol.pytest import AgentRolloutProcessor, evaluation_test |
6 | 5 |
|
@@ -41,3 +40,53 @@ def test_pytest_mcp_config(row: EvaluationRow) -> EvaluationRow: |
41 | 40 | reason="At least one tool call was made", |
42 | 41 | ) |
43 | 42 | return row |
| 43 | + |
| 44 | + |
@pytest.mark.asyncio
async def test_pytest_tools_are_added_to_row():
    """Check that a row logged after an agent rollout carries the expected tools.

    The test wires a capturing logger into ``evaluation_test``, runs a single
    rollout through ``AgentRolloutProcessor`` with the mock Discord MCP
    configuration, and then inspects the tool names on the captured row.
    """

    class TrackingLogger(DatasetLogger):
        """Logger that records every logged row, keyed by its rollout id."""

        def __init__(self, rollouts: dict[str, EvaluationRow]):
            # The dict is owned by the enclosing test, so the test can inspect
            # what was logged once the evaluation has finished.
            self.rollouts = rollouts

        def log(self, row: EvaluationRow):
            # A later log call for the same rollout id overwrites the earlier
            # entry, so only the most recently logged row per rollout is kept.
            self.rollouts[row.execution_metadata.rollout_id] = row

        def read(self):
            # Nothing is ever read back through this logger in the test.
            return []

    input_messages = [
        [
            Message(
                role="system",
                content="You are a helpful assistant that can answer questions about Fireworks.",
            ),
        ]
    ]
    completion_params_list = [
        {"model": "dummy/local-model"},
    ]

    rollouts: dict[str, EvaluationRow] = {}
    logger = TrackingLogger(rollouts)

    @evaluation_test(
        input_messages=input_messages,
        completion_params=completion_params_list,
        rollout_processor=AgentRolloutProcessor(),
        mode="pointwise",
        mcp_config_path="tests/pytest/mcp_configurations/mock_discord_mcp_config.json",
        logger=logger,
    )
    def eval_fn(row: EvaluationRow) -> EvaluationRow:
        # Identity evaluation: this test only cares about what gets logged.
        return row

    await eval_fn(input_messages=input_messages, completion_params=completion_params_list[0])

    # ensure that the row has tools that were set during AgentRolloutProcessor
    assert len(rollouts) == 1, f"expected exactly one logged rollout, got {len(rollouts)}"
    row = next(iter(rollouts.values()))
    # Fail with a readable message instead of a TypeError when no tools were attached.
    assert row.tools is not None, "logged row has no tools attached"
    tool_names = sorted(tool["function"].name for tool in row.tools)
    assert tool_names == sorted(["list_servers", "get_channels", "read_messages"])
0 commit comments