Skip to content

Commit 88ac21a

Browse files
author
Dylan Huang
committed
ensure tools are present from agent rollout
1 parent 075f39c commit 88ac21a

File tree

2 files changed

+55
-7
lines changed

2 files changed

+55
-7
lines changed

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,16 @@ def __init__(self, model: str, row: EvaluationRow, config_path: str, logger: Dat
2929
self.evaluation_row: EvaluationRow = row
3030
self._policy = LiteLLMPolicy(model_id=model)
3131
self.mcp_client = MCPMultiClient(config_path=config_path) if config_path else None
32-
self.tools: Union[List[ChatCompletionToolParam], NotGiven] = NOT_GIVEN
3332
self.logger: DatasetLogger = logger
3433

3534
async def setup(self):
3635
if self.mcp_client:
3736
await self.mcp_client.connect_to_servers()
3837

3938
async def _get_tools(self) -> Optional[List[ChatCompletionToolParam]]:
40-
if self.tools is NOT_GIVEN:
41-
self.tools = await self.mcp_client.get_available_tools() if self.mcp_client else None
42-
return self.tools
39+
if self.evaluation_row.tools is None:
40+
self.evaluation_row.tools = await self.mcp_client.get_available_tools() if self.mcp_client else None
41+
return self.evaluation_row.tools
4342

4443
@property
4544
def messages(self) -> list[Message]:

tests/pytest/test_pytest_mcp_config.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from datetime import datetime
2-
from typing import List
3-
1+
import pytest
2+
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
43
from eval_protocol.models import EvaluateResult, EvaluationRow, Message
54
from eval_protocol.pytest import AgentRolloutProcessor, evaluation_test
65

@@ -41,3 +40,53 @@ def test_pytest_mcp_config(row: EvaluationRow) -> EvaluationRow:
4140
reason="At least one tool call was made",
4241
)
4342
return row
43+
44+
45+
@pytest.mark.asyncio
46+
async def test_pytest_tools_are_added_to_row():
47+
class TrackingLogger(DatasetLogger):
48+
"""Custom logger that records every logged row, keyed by its rollout ID."""
49+
50+
def __init__(self, rollouts: dict[str, EvaluationRow]):
51+
self.rollouts = rollouts
52+
53+
def log(self, row: EvaluationRow):
54+
self.rollouts[row.execution_metadata.rollout_id] = row
55+
56+
def read(self):
57+
return []
58+
59+
input_messages = [
60+
[
61+
Message(
62+
role="system",
63+
content="You are a helpful assistant that can answer questions about Fireworks.",
64+
),
65+
]
66+
]
67+
completion_params_list = [
68+
{"model": "dummy/local-model"},
69+
]
70+
71+
rollouts: dict[str, EvaluationRow] = {}
72+
logger = TrackingLogger(rollouts)
73+
74+
@evaluation_test(
75+
input_messages=input_messages,
76+
completion_params=completion_params_list,
77+
rollout_processor=AgentRolloutProcessor(),
78+
mode="pointwise",
79+
mcp_config_path="tests/pytest/mcp_configurations/mock_discord_mcp_config.json",
80+
logger=logger,
81+
)
82+
def eval_fn(row: EvaluationRow) -> EvaluationRow:
83+
return row
84+
85+
await eval_fn(input_messages=input_messages, completion_params=completion_params_list[0])
86+
87+
# ensure that the logged row carries the tools that AgentRolloutProcessor set during the rollout
88+
assert len(rollouts) == 1
89+
row = list(rollouts.values())[0]
90+
assert sorted([tool["function"].name for tool in row.tools]) == sorted(
91+
["list_servers", "get_channels", "read_messages"]
92+
)

0 commit comments

Comments
 (0)