diff --git a/tests/chinook/pydantic/agent.py b/tests/chinook/pydantic/agent.py index 33bf6d58..ec6e27a1 100644 --- a/tests/chinook/pydantic/agent.py +++ b/tests/chinook/pydantic/agent.py @@ -33,7 +33,7 @@ def setup_agent(orchestrator_agent_model: Model): ) @agent.tool(retries=5) - def execute_sql(ctx: RunContext, query: str) -> dict: + def execute_sql(ctx: RunContext, query: str) -> str: try: cursor.execute(query) # Get column headers from cursor description @@ -69,7 +69,7 @@ def execute_sql(ctx: RunContext, query: str) -> dict: async def main(): model = OpenAIModel( - model="accounts/fireworks/models/kimi-k2-instruct", + "accounts/fireworks/models/kimi-k2-instruct", provider="fireworks", ) agent = setup_agent(model) diff --git a/tests/chinook/pydantic/test_pydantic_chinook.py b/tests/chinook/pydantic/test_pydantic_chinook.py index 0233a227..400bfa25 100644 --- a/tests/chinook/pydantic/test_pydantic_chinook.py +++ b/tests/chinook/pydantic/test_pydantic_chinook.py @@ -84,68 +84,3 @@ class Response(BaseModel): reason=result.output.reason, ) return row - - -@pytest.mark.skipif( - os.environ.get("CI") == "true", - reason="Only run this test locally (skipped in CI)", -) -@pytest.mark.asyncio -@evaluation_test( - input_rows=[collect_dataset()], - completion_params=[ - { - "model": "accounts/fireworks/models/kimi-k2-instruct", - "provider": "fireworks", - }, - ], - rollout_processor=PydanticAgentRolloutProcessor(agent_factory), - mode="pointwise", -) -async def test_complex_queries(row: EvaluationRow) -> EvaluationRow: - """ - Complex queries for the Chinook database - """ - last_assistant_message = row.last_assistant_message() - if last_assistant_message is None: - row.evaluation_result = EvaluateResult( - score=0.0, - reason="No assistant message found", - ) - elif not last_assistant_message.content: - row.evaluation_result = EvaluateResult( - score=0.0, - reason="No assistant message found", - ) - else: - model = OpenAIModel( - "accounts/fireworks/models/kimi-k2-instruct", - provider="fireworks", - ) - - class Response(BaseModel): - """ - A score between 0.0 and 1.0 indicating whether the response is correct. - """ - - score: float - - """ - A short explanation of why the response is correct or incorrect. - """ - reason: str - - comparison_agent = Agent( - model=model, - system_prompt=LLM_JUDGE_PROMPT, - output_type=Response, - output_retries=5, - ) - result = await comparison_agent.run( - f"Expected answer: {row.ground_truth}\nResponse: {last_assistant_message.content}" - ) - row.evaluation_result = EvaluateResult( - score=result.output.score, - reason=result.output.reason, - ) - return row diff --git a/tests/chinook/pydantic/test_pydantic_complex_queries.py b/tests/chinook/pydantic/test_pydantic_complex_queries.py new file mode 100644 index 00000000..8b750f03 --- /dev/null +++ b/tests/chinook/pydantic/test_pydantic_complex_queries.py @@ -0,0 +1,91 @@ +import os +from pydantic import BaseModel +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIModel +import pytest + +from eval_protocol.models import EvaluateResult, EvaluationRow +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.types import RolloutProcessorConfig +from tests.chinook.dataset import collect_dataset +from tests.chinook.pydantic.agent import setup_agent +from tests.pytest.test_pydantic_agent import PydanticAgentRolloutProcessor + +LLM_JUDGE_PROMPT = ( + "Your job is to compare the response to the expected answer.\n" + "The response will be a narrative report of the query results.\n" + "If the response contains the same or well summarized information as the expected answer, return 1.0.\n" + "If the response does not contain the same information or is missing information, return 0.0." +) + + +def agent_factory(config: RolloutProcessorConfig) -> Agent: + model_name = config.completion_params["model"] + provider = config.completion_params["provider"] + model = OpenAIModel(model_name, provider=provider) + return setup_agent(model) + + +@pytest.mark.skipif( + os.environ.get("CI") == "true", + reason="Only run this test locally (skipped in CI)", +) +@pytest.mark.asyncio +@evaluation_test( + input_rows=[collect_dataset()], + completion_params=[ + { + "model": "accounts/fireworks/models/kimi-k2-instruct", + "provider": "fireworks", + }, + ], + rollout_processor=PydanticAgentRolloutProcessor(agent_factory), + mode="pointwise", +) +async def test_pydantic_complex_queries(row: EvaluationRow) -> EvaluationRow: + """ + Evaluation of complex queries for the Chinook database using PydanticAI + """ + last_assistant_message = row.last_assistant_message() + if last_assistant_message is None: + row.evaluation_result = EvaluateResult( + score=0.0, + reason="No assistant message found", + ) + elif not last_assistant_message.content: + row.evaluation_result = EvaluateResult( + score=0.0, + reason="No assistant message found", + ) + else: + model = OpenAIModel( + "accounts/fireworks/models/kimi-k2-instruct", + provider="fireworks", + ) + + class Response(BaseModel): + """ + A score between 0.0 and 1.0 indicating whether the response is correct. + """ + + score: float + + """ + A short explanation of why the response is correct or incorrect. + """ + reason: str + + comparison_agent = Agent( + model=model, + system_prompt=LLM_JUDGE_PROMPT, + output_type=Response, + output_retries=5, + ) + result = await comparison_agent.run( + f"Expected answer: {row.ground_truth}\nResponse: {last_assistant_message.content}" + ) + row.evaluation_result = EvaluateResult( + score=result.output.score, + reason=result.output.reason, + ) + return row