Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ jobs:
E2B_API_KEY: ${{ secrets.E2B_API_KEY }}
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
FIREWORKS_ACCOUNT_ID: ${{ secrets.FIREWORKS_ACCOUNT_ID }}
SUPABASE_PASSWORD: ${{ secrets.SUPABASE_PASSWORD }}
SUPABASE_HOST: ${{ secrets.SUPABASE_HOST }}
SUPABASE_PORT: ${{ secrets.SUPABASE_PORT }}
SUPABASE_DATABASE: ${{ secrets.SUPABASE_DATABASE }}
SUPABASE_USER: ${{ secrets.SUPABASE_USER }}
PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning"
run: |
# Run most tests in parallel, but explicitly ignore tests that manage their own servers or are slow
Expand Down
7 changes: 7 additions & 0 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,13 @@ def get_assistant_messages(self) -> List[Message]:
"""Returns only the assistant messages from the conversation."""
return [msg for msg in self.messages if msg.role == "assistant"]

def last_assistant_message(self) -> Optional[Message]:
"""Returns the last assistant message from the conversation. Returns None if none found."""
assistant_messages = self.get_assistant_messages()
if not assistant_messages:
return None
return assistant_messages[-1]

def get_user_messages(self) -> List[Message]:
"""Returns only the user messages from the conversation."""
return [msg for msg in self.messages if msg.role == "user"]
Expand Down
10 changes: 6 additions & 4 deletions eval_protocol/pytest/default_pydantic_ai_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import types
from typing import List

from attr import dataclass
from openai.types.chat.chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam

from eval_protocol.models import EvaluationRow, Message
Expand All @@ -23,6 +24,7 @@
UserPromptPart,
)
from pydantic_ai.providers.openai import OpenAIProvider
from typing_extensions import TypedDict

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,10 +60,10 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
"completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
)
kwargs = {}
for model_name, model_config in config.completion_params["model"].items():
kwargs[model_name] = OpenAIModel(
model_config["model"],
provider=model_config["provider"],
for k, v in config.completion_params["model"].items():
kwargs[k] = OpenAIModel(
v["model"],
provider=v["provider"],
)
agent = setup_agent(**kwargs)
model = None
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ svgbench = [
pydantic = [
"pydantic-ai",
]
supabase = [
"supabase>=2.18.1",
]
chinook = [
"psycopg2-binary>=2.9.10",
]

[tool.pytest.ini_options]
addopts = "-q"
Expand Down
Loading
Loading