Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ repos:
hooks:
- id: basedpyright
args: ["--level", "error"]
env:
NODE_OPTIONS: "--max-old-space-size=4096"
# Only check Python files in the main package to reduce memory usage
files: ^eval_protocol/.*\.py$
15 changes: 9 additions & 6 deletions eval_protocol/benchmarks/test_tau_bench_airline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from eval_protocol.pytest import evaluation_test, ExceptionHandlerConfig
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
import litellm
from litellm.exceptions import RateLimitError, APIConnectionError
from vendor.tau2.data_model.message import (
AssistantMessage,
SystemMessage,
Expand Down Expand Up @@ -125,8 +126,8 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Eval
server_script_path=_get_server_script_path(),
exception_handler_config=ExceptionHandlerConfig(
retryable_exceptions={
litellm.RateLimitError,
litellm.APIConnectionError,
RateLimitError,
APIConnectionError,
}
),
)
Expand Down Expand Up @@ -159,8 +160,10 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
role = msg.role
content = msg.content

# Normalize content to str for tau2 message models
text_content = content if isinstance(content, str) or content is None else ""
if role == "system":
trajectory_objects.append(SystemMessage(role=role, content=content))
trajectory_objects.append(SystemMessage(role=role, content=text_content))
elif role == "assistant":
tau2_tool_calls = []
if msg.tool_calls:
Expand All @@ -173,12 +176,12 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
)
tau2_tool_calls.append(tau2_tool_call)

trajectory_objects.append(AssistantMessage(role=role, content=content, tool_calls=tau2_tool_calls))
trajectory_objects.append(AssistantMessage(role=role, content=text_content, tool_calls=tau2_tool_calls))
elif role == "user":
trajectory_objects.append(UserMessage(role=role, content=content))
trajectory_objects.append(UserMessage(role=role, content=text_content))
elif role == "tool":
tool_id = msg.tool_call_id
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content))
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=text_content))

reward = 1.0

Expand Down
15 changes: 9 additions & 6 deletions eval_protocol/benchmarks/test_tau_bench_retail.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from eval_protocol.pytest import evaluation_test, ExceptionHandlerConfig
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
import litellm
from litellm.exceptions import RateLimitError, APIConnectionError
from vendor.tau2.data_model.message import (
AssistantMessage,
SystemMessage,
Expand Down Expand Up @@ -115,8 +116,8 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
server_script_path=get_server_script_path(),
exception_handler_config=ExceptionHandlerConfig(
retryable_exceptions={
litellm.RateLimitError,
litellm.APIConnectionError,
RateLimitError,
APIConnectionError,
}
),
)
Expand Down Expand Up @@ -149,8 +150,10 @@ def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow:
role = msg.role
content = msg.content

# Normalize content to str for tau2 message models
text_content = content if isinstance(content, str) or content is None else ""
if role == "system":
trajectory_objects.append(SystemMessage(role=role, content=content))
trajectory_objects.append(SystemMessage(role=role, content=text_content))
elif role == "assistant":
tau2_tool_calls = []
if msg.tool_calls:
Expand All @@ -163,12 +166,12 @@ def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow:
)
tau2_tool_calls.append(tau2_tool_call)

trajectory_objects.append(AssistantMessage(role=role, content=content, tool_calls=tau2_tool_calls))
trajectory_objects.append(AssistantMessage(role=role, content=text_content, tool_calls=tau2_tool_calls))
elif role == "user":
trajectory_objects.append(UserMessage(role=role, content=content))
trajectory_objects.append(UserMessage(role=role, content=text_content))
elif role == "tool":
tool_id = msg.tool_call_id
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content))
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=text_content))

reward = 1.0

Expand Down
22 changes: 18 additions & 4 deletions eval_protocol/cli_commands/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

# TODO: Consider moving subprocess_manager functions to a more central location if used by core CLI
try:
# Import functions with explicit names to match expected signatures
from development.utils.subprocess_manager import (
start_ngrok_and_get_url, # Added ngrok function
start_process,
start_serveo_and_get_url,
stop_process,
start_ngrok_and_get_url as _start_ngrok_and_get_url,
start_process as _start_process,
start_serveo_and_get_url as _start_serveo_and_get_url,
stop_process as _stop_process,
)
except ImportError:
# Fallback implementations when development module is not available
Expand Down Expand Up @@ -56,6 +57,19 @@ def start_ngrok_and_get_url(local_port, log_path):
"""Fallback ngrok tunnel - returns None to indicate unavailable."""
print("ngrok tunneling not available - development module not found")
return None, None
else:
# Wrap imported helpers to present consistent, simple signatures used below
def start_process(command, log_path, env=None):
return _start_process(command=command, log_file_path=log_path, env=env)

def stop_process(pid):
return _stop_process(pid)

def start_serveo_and_get_url(local_port, log_path):
return _start_serveo_and_get_url(local_port=local_port, log_path=log_path)

def start_ngrok_and_get_url(local_port, log_path):
return _start_ngrok_and_get_url(local_port=local_port, ngrok_log_file=log_path)


from eval_protocol.auth import get_fireworks_account_id
Expand Down
9 changes: 6 additions & 3 deletions eval_protocol/mcp/execution/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

import litellm
from litellm import acompletion, completion
from litellm.caching import Cache, DualCache, InMemoryCache, RedisCache
from litellm.caching.caching import Cache
from litellm.caching.dual_cache import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.caching.redis_cache import RedisCache

from .base_policy import LLMBasePolicy

Expand Down Expand Up @@ -108,13 +111,13 @@ def _setup_litellm_caching(
logger.info("🗄️ Initialized dual caching (memory + Redis)")

elif cache_type == "disk":
from litellm.caching import DiskCache
from litellm.caching.disk_cache import DiskCache

litellm.cache = DiskCache()
logger.info("🗄️ Initialized disk caching")

elif cache_type == "s3":
from litellm.caching import S3Cache
from litellm.caching.s3_cache import S3Cache

litellm.cache = S3Cache()
logger.info("🗄️ Initialized S3 caching")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from anyio.abc import ObjectReceiveStream, ObjectSendStream

# ListToolsResult is not in mcp.client.session, likely in mcp.types or mcp.shared.message
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession, SessionMessage
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession

# Assuming ListToolsResult is in mcp.types, which is imported as types
# If not, this will need further correction. For now, we'll use types.ListToolsResult
Expand Down
8 changes: 5 additions & 3 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

from openai.types import CompletionUsage
from openai.types.chat.chat_completion_message import (
ChatCompletionMessageToolCall,
FunctionCall,
)
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from pydantic import BaseModel, ConfigDict, Field

from eval_protocol.get_pep440_version import get_pep440_version
Expand Down Expand Up @@ -519,7 +521,7 @@ class EvaluationRow(BaseModel):

# Input-related metadata (grouped together for cleaner organization)
input_metadata: InputMetadata = Field(
default_factory=InputMetadata,
default_factory=lambda: InputMetadata(),
description="Metadata related to the input (dataset info, model config, session data, etc.).",
)

Expand All @@ -539,7 +541,7 @@ class EvaluationRow(BaseModel):
)

execution_metadata: ExecutionMetadata = Field(
default_factory=ExecutionMetadata,
default_factory=lambda: ExecutionMetadata(),
description="Metadata about the execution of the evaluation.",
)

Expand Down
21 changes: 18 additions & 3 deletions eval_protocol/rewards/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,22 @@
import re
from typing import Any, Callable, Dict, List, Optional, Union, cast

from ..models import EvaluateResult, Message, MetricResult
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam


def _to_text(content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]]) -> str:
    """Coerce Message.content into a plain string for regex and comparisons.

    Returns the string unchanged, "" for None, and the newline-joined ``.text``
    fields for a list of text content parts. Anything malformed yields "".
    """
    if isinstance(content, str):
        return content
    if content is None:
        return ""
    # List of content parts: join each part's text, one per line.
    try:
        joined = "\n".join([part.text for part in content])
    except Exception:
        # Best-effort: a part without .text (or a non-iterable) maps to "".
        return ""
    return joined


from ..typed_interface import reward_function


Expand Down Expand Up @@ -334,7 +349,7 @@ def accuracy_reward(
model_last_message = messages[-1]
if isinstance(model_last_message, Message):
if model_last_message.role == "assistant" and model_last_message.content is not None:
model_response_text = model_last_message.content
model_response_text = _to_text(model_last_message.content)
else:
return EvaluateResult(
score=0.0,
Expand Down Expand Up @@ -386,7 +401,7 @@ def accuracy_reward(
first_gt_message = ground_truth[0]
if isinstance(first_gt_message, Message):
if first_gt_message.content is not None:
ground_truth_comparison_text = first_gt_message.content
ground_truth_comparison_text = _to_text(first_gt_message.content)
else:
return EvaluateResult(
score=0.0,
Expand Down
12 changes: 10 additions & 2 deletions eval_protocol/rewards/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from typing import Any, Dict, List, Optional, Union

from ..models import EvaluateResult, Message, MetricResult
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
from ..typed_interface import reward_function
from .function_calling import (
calculate_jaccard_similarity,
Expand Down Expand Up @@ -54,7 +54,15 @@ def json_schema_reward(

if isinstance(last_message, Message):
if last_message.role == "assistant" and last_message.content is not None:
content_text = last_message.content
# Coerce to string if content is list parts
if isinstance(last_message.content, str):
content_text = last_message.content
else:
try:
parts: List[ChatCompletionContentPartTextParam] = last_message.content # type: ignore[assignment]
content_text = "\n".join(p.text for p in parts)
except Exception:
content_text = ""
else:
return EvaluateResult(
score=0.0,
Expand Down
23 changes: 14 additions & 9 deletions eval_protocol/rewards/language_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import re
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from ..models import EvaluateResult, Message, MetricResult
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
from ..typed_interface import reward_function

# Dictionary mapping language codes to common words/patterns in that language
Expand Down Expand Up @@ -560,12 +560,7 @@ def language_consistency_reward(
Returns:
EvaluateResult with score based on language consistency.
"""
if (
not messages
or not isinstance(messages[-1], Message)
or messages[-1].role != "assistant"
or messages[-1].content is None
):
if not messages or not isinstance(messages[-1], Message) or messages[-1].role != "assistant":
return EvaluateResult(
score=0.0,
reason="Invalid or missing assistant response in messages.",
Expand All @@ -578,7 +573,17 @@ def language_consistency_reward(
},
)

text_to_evaluate = messages[-1].content
def _to_text(content: Union[str, List[ChatCompletionContentPartTextParam], None]) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
try:
return "\n".join(part.text for part in content)
except Exception:
return ""

text_to_evaluate = _to_text(messages[-1].content)

# For test_spanish_consistency - special handling for Spanish test case
if "está escrita completamente en español" in text_to_evaluate:
Expand All @@ -593,7 +598,7 @@ def language_consistency_reward(
prompt_messages = messages[:-1]
for msg in prompt_messages:
if isinstance(msg, Message) and msg.role == "user": # Decorator ensures msg is Message
content_text: str = msg.content if msg.content is not None else ""
content_text: str = _to_text(msg.content)
if "in Spanish" in content_text:
target_language = "es"
break
Expand Down
17 changes: 15 additions & 2 deletions eval_protocol/rewards/repetition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,20 @@
import re
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

from ..models import EvaluateResult, Message, MetricResult
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam


def _to_text(content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]]) -> str:
    """Return message content as plain text.

    str passes through; None becomes ""; a list of text parts is joined by
    newlines; any malformed input collapses to "" rather than raising.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    result = ""
    try:
        result = "\n".join(part.text for part in content)
    except Exception:
        # Defensive: a part missing .text (or a non-iterable) yields "".
        result = ""
    return result


from ..typed_interface import reward_function


Expand Down Expand Up @@ -94,7 +107,7 @@ def repetition_penalty_reward(
)
},
)
text = response.content or ""
text = _to_text(response.content)
elif isinstance(response, dict):
if response.get("role") != "assistant":
return EvaluateResult(
Expand Down
19 changes: 16 additions & 3 deletions eval_protocol/rewards/tag_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,20 @@
import re
from typing import Any, Dict, List, Set, Union

from ..models import EvaluateResult, Message, MetricResult
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam


def _to_text(content: Union[str, List[ChatCompletionContentPartTextParam], None]) -> str:
    """Normalize message content to a str ("" for None or unreadable parts)."""
    if not isinstance(content, str):
        if content is None:
            return ""
        # Content-part list: newline-join the .text of each part.
        try:
            return "\n".join(part.text for part in content)
        except Exception:
            # Swallow malformed parts on purpose — callers expect a str.
            return ""
    return content


from ..typed_interface import reward_function


Expand Down Expand Up @@ -46,7 +59,7 @@ def tag_count_reward(

response = messages[-1]

if response.role != "assistant" or not response.content:
if response.role != "assistant" or response.content is None:
return EvaluateResult(
score=0.0,
reason="No assistant response found or response has no content",
Expand All @@ -58,7 +71,7 @@ def tag_count_reward(
)
},
)
text: str = response.content
text: str = _to_text(response.content)

tag_metrics = {}
found_tags: Set[str] = set()
Expand Down
Loading
Loading