Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion eval_protocol/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def run_command(
if verbose:
print("unable to find command, tried %s" % (commands,))
return None, None
stdout = process.communicate()[0].strip().decode()
stdout_bytes = process.communicate()[0]
stdout_raw = stdout_bytes.decode() if isinstance(stdout_bytes, (bytes, bytearray)) else stdout_bytes
stdout = str(stdout_raw).strip()
if process.returncode != 0:
if verbose:
print("unable to run %s (error)" % dispcmd)
Expand Down
38 changes: 24 additions & 14 deletions eval_protocol/adapters/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,36 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Union, cast, TypeAlias

from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message

logger = logging.getLogger(__name__)

try:
# Import at runtime if available
from google.auth.exceptions import DefaultCredentialsError
from google.cloud import bigquery
from google.cloud import bigquery as _bigquery_runtime # type: ignore
from google.cloud.exceptions import Forbidden, NotFound
from google.oauth2 import service_account

BIGQUERY_AVAILABLE = True
except ImportError:
# Provide fallbacks for type checking/runtime when package is missing
DefaultCredentialsError = Exception # type: ignore[assignment]
Forbidden = Exception # type: ignore[assignment]
NotFound = Exception # type: ignore[assignment]
service_account: Any
service_account = None
_bigquery_runtime = None # type: ignore[assignment]
BIGQUERY_AVAILABLE = False
# Optional dependency: avoid noisy warnings during import
logger.debug("Google Cloud BigQuery not installed. Optional feature disabled.")

# Avoid importing BigQuery types at runtime for annotations when not installed
if TYPE_CHECKING:
from google.cloud import bigquery as _bigquery_type

QueryParameterType = Union[
_bigquery_type.ScalarQueryParameter,
_bigquery_type.ArrayQueryParameter,
]
else:
QueryParameterType = Any
# Simple type aliases to avoid importing optional google types under pyright
QueryParameterType: TypeAlias = Any
BigQueryClient: TypeAlias = Any
QueryJobConfig: TypeAlias = Any

# Type alias for transformation function
TransformFunction = Callable[[Dict[str, Any]], Dict[str, Any]]
Expand Down Expand Up @@ -98,7 +100,13 @@
client_args["location"] = location

client_args.update(client_kwargs)
self.client = bigquery.Client(**client_args)
# Use runtime alias to avoid basedpyright import symbol error when lib is missing
if _bigquery_runtime is None:
raise ImportError(
"google-cloud-bigquery is not installed. Install with: pip install 'eval-protocol[bigquery]'"
)
# Avoid strict typing on optional dependency
self.client = _bigquery_runtime.Client(**client_args) # type: ignore[no-untyped-call, assignment]

except DefaultCredentialsError as e:
logger.error("Failed to authenticate with BigQuery: %s", e)
Expand Down Expand Up @@ -139,7 +147,9 @@
"""
try:
# Configure query job
job_config = bigquery.QueryJobConfig()
if _bigquery_runtime is None:
raise RuntimeError("BigQuery runtime not available")
job_config = _bigquery_runtime.QueryJobConfig() # type: ignore[no-untyped-call]
if query_params:
job_config.query_parameters = query_params
if self.location:
Expand Down Expand Up @@ -184,7 +194,7 @@
except (NotFound, Forbidden) as e:
logger.error("BigQuery access error: %s", e)
raise
except Exception as e:

Check failure on line 197 in eval_protocol/adapters/bigquery.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Except clause is unreachable because exception is already handled   "Exception" is a subclass of "Exception" (reportUnusedExcept)
logger.error("Error executing BigQuery query: %s", e)
raise

Expand Down
15 changes: 10 additions & 5 deletions eval_protocol/adapters/langchain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import os
from typing import List
from typing import Any, Dict, List, Optional

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage

Expand Down Expand Up @@ -49,10 +49,10 @@ def serialize_lc_message_to_ep(msg: BaseMessage) -> Message:
parts.append(item)
content = "\n".join(parts)

tool_calls_payload = None
tool_calls_payload: Optional[List[Dict[str, Any]]] = None

def _normalize_tool_calls(tc_list: list) -> list[dict]:
mapped: List[dict] = []
def _normalize_tool_calls(tc_list: List[Any]) -> List[Dict[str, Any]]:
mapped: List[Dict[str, Any]] = []
for call in tc_list:
if not isinstance(call, dict):
continue
Expand Down Expand Up @@ -104,8 +104,13 @@ def _normalize_tool_calls(tc_list: list) -> list[dict]:
if collected:
reasoning_content = "\n\n".join([s for s in collected if s]) or None

# Message.tool_calls expects List[ChatCompletionMessageToolCall] | None.
# We pass through Dicts at runtime but avoid type error by casting.
ep_msg = Message(
role="assistant", content=content, tool_calls=tool_calls_payload, reasoning_content=reasoning_content
role="assistant",
content=content,
tool_calls=tool_calls_payload, # type: ignore[arg-type]
reasoning_content=reasoning_content,
)
_dbg_print(
"[EP-Ser] -> EP Message:",
Expand Down
4 changes: 2 additions & 2 deletions eval_protocol/adapters/langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

import logging
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional
from typing import Any, Dict, Iterator, List, Optional, cast

from eval_protocol.models import EvaluationRow, InputMetadata, Message

logger = logging.getLogger(__name__)

try:
from langfuse import Langfuse

Check failure on line 16 in eval_protocol/adapters/langfuse.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

"Langfuse" is not exported from module "langfuse"   Import from "langfuse._client.client" instead (reportPrivateImportUsage)

LANGFUSE_AVAILABLE = True
except ImportError:
Expand Down Expand Up @@ -63,7 +63,7 @@
if not LANGFUSE_AVAILABLE:
raise ImportError("Langfuse not installed. Install with: pip install 'eval-protocol[langfuse]'")

self.client = Langfuse(public_key=public_key, secret_key=secret_key, host=host)
self.client = cast(Any, Langfuse)(public_key=public_key, secret_key=secret_key, host=host)
self.project_id = project_id

def get_evaluation_rows(
Expand Down
2 changes: 2 additions & 0 deletions eval_protocol/benchmarks/test_gpqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def _extract_abcd_letter(text: str) -> str | None:


def _strip_gt_messages(msgs: list[Message]) -> list[Message]:
    """Return *msgs* with ground-truth system messages removed.

    A ground-truth message is a ``system``-role message whose content starts
    with the ``"__GT__:"`` marker.  All inputs are required to carry a plain
    ``str`` content; anything else trips the assertion below.
    """
    # Precondition: every message exposes a plain string .content field.
    assert all(isinstance(message.content, str) for message in msgs), "Messages must have a plain .content string field"

    kept: list[Message] = []
    for message in msgs:
        is_ground_truth = message.role == "system" and (message.content or "").startswith("__GT__:")
        if not is_ground_truth:
            kept.append(message)
    return kept


Expand Down
1 change: 1 addition & 0 deletions eval_protocol/benchmarks/test_tau_bench_retail.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow:
task = Task(
id="Filler", evaluation_criteria=evaluation_criteria, user_scenario=UserScenario(instructions="Filler")
) # id and user_scenario are required for the Task type but not used in calculating reward
assert task.evaluation_criteria is not None, "Task evaluation criteria is None"

if RewardType.DB in task.evaluation_criteria.reward_basis:
env_reward_info = EnvironmentEvaluator.calculate_reward(
Expand Down
15 changes: 7 additions & 8 deletions eval_protocol/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,22 +257,21 @@ def load_metric_folder(self, metric_name, folder_path):
for keyword in decorator_node.keywords:
if keyword.arg == "requirements":
if isinstance(keyword.value, ast.List):
reqs = []
reqs: List[str] = []
for elt in keyword.value.elts:
if isinstance(elt, ast.Constant) and isinstance(
elt.value, str
): # Python 3.8+
reqs.append(elt.value)
if isinstance(elt, ast.Constant): # Python 3.8+
if isinstance(elt.value, str):
reqs.append(cast(str, elt.value))
elif isinstance(elt, ast.Str): # Python < 3.8
reqs.append(elt.s)
reqs.append(cast(str, elt.s))
if reqs:
metric_requirements_list = cast(List[str], reqs)
elif isinstance(keyword.value, ast.Constant) and isinstance(
keyword.value.value, str
): # Python 3.8+ (single req string)
metric_requirements_list = [keyword.value.value]
metric_requirements_list = [cast(str, keyword.value.value)]
elif isinstance(keyword.value, ast.Str): # Python < 3.8 (single req string)
metric_requirements_list = [keyword.value.s]
metric_requirements_list = [cast(str, keyword.value.s)]
break
if metric_requirements_list:
break
Expand Down
20 changes: 15 additions & 5 deletions eval_protocol/mcp/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,24 +441,34 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict)
# Extract data plane results (observation only)
if tool_result.content and len(tool_result.content) > 0:
content = tool_result.content[0]
if hasattr(content, "text"):
# Safely attempt to read a "text" attribute if present across content types
text_attr = getattr(content, "text", None)
if isinstance(text_attr, str):
content_text = text_attr
elif isinstance(text_attr, list):
# text can also be an array of parts with optional .text fields
content_text = "".join([getattr(p, "text", "") for p in text_attr])
else:
content_text = None

if isinstance(content_text, str):
# Fix: Handle empty or invalid JSON responses gracefully
if not content.text or content.text.strip() == "":
if content_text.strip() == "":
logger.warning(f"Session {session.session_id}: Empty tool response from {tool_name}")
observation = {
"observation": "empty_response",
"session_id": session.session_id,
}
else:
try:
observation = json.loads(content.text)
observation = json.loads(content_text)
except json.JSONDecodeError as e:
logger.warning(
f"Session {session.session_id}: Invalid JSON from {tool_name}: {content.text}. Error: {e}"
f"Session {session.session_id}: Invalid JSON from {tool_name}: {content_text}. Error: {e}"
)
# Create a structured response from the raw text
observation = {
"observation": content.text,
"observation": content_text,
"session_id": session.session_id,
"error": "invalid_json_response",
}
Expand Down
51 changes: 32 additions & 19 deletions eval_protocol/mcp/execution/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,19 @@ def _setup_litellm_caching(
logger.info("🗄️ Initialized disk caching")

elif cache_type == "s3":
from litellm.caching.s3_cache import S3Cache

litellm.cache = S3Cache()
logger.info("🗄️ Initialized S3 caching")
try:
from litellm.caching.s3_cache import S3Cache

# Some versions require positional or named 's3_bucket_name'
s3_bucket_name = os.getenv("LITELLM_S3_BUCKET")
if not s3_bucket_name:
raise ValueError("Missing LITELLM_S3_BUCKET for S3 cache")
# Use explicit arg name expected by basedpyright
litellm.cache = S3Cache(s3_bucket_name=s3_bucket_name)
logger.info("🗄️ Initialized S3 caching for bucket %s", s3_bucket_name)
except Exception as e:
logger.warning(f"Failed to initialize S3 cache ({e}); falling back to in-memory cache")
litellm.cache = Cache()

except Exception as e:
logger.warning(f"Failed to setup {cache_type} caching: {e}. Falling back to in-memory cache.")
Expand All @@ -147,7 +156,7 @@ def _clean_messages_for_api(self, messages: List[Dict]) -> List[Dict]:
clean_messages.append(clean_msg)
return clean_messages

async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Make an LLM API call with retry logic and caching.

Expand All @@ -162,7 +171,7 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
clean_messages = self._clean_messages_for_api(messages)

# Prepare request parameters
request_params = {
request_params: Dict[str, Any] = {
"messages": clean_messages,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
Expand All @@ -188,7 +197,8 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
response = await acompletion(model=self.model_id, **request_params)

# Log cache hit/miss for monitoring
cache_hit = getattr(response, "_hidden_params", {}).get("cache_hit", False)
hidden = getattr(response, "_hidden_params", {})
cache_hit = hidden.get("cache_hit", False) if isinstance(hidden, dict) else False
if cache_hit:
logger.debug(f"🎯 Cache hit for model: {self.model_id}")
else:
Expand All @@ -199,31 +209,34 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
"choices": [
{
"message": {
"role": response.choices[0].message.role,
"content": response.choices[0].message.content,
"role": getattr(getattr(response.choices[0], "message", object()), "role", "assistant"),
"content": getattr(getattr(response.choices[0], "message", object()), "content", None),
"tool_calls": (
[
{
"id": tc.id,
"type": tc.type,
"id": getattr(tc, "id", None),
"type": getattr(tc, "type", "function"),
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
"name": getattr(getattr(tc, "function", None), "name", "tool"),
"arguments": getattr(getattr(tc, "function", None), "arguments", "{}"),
},
}
for tc in (response.choices[0].message.tool_calls or [])
for tc in (
getattr(getattr(response.choices[0], "message", object()), "tool_calls", [])
or []
)
]
if response.choices[0].message.tool_calls
if getattr(getattr(response.choices[0], "message", object()), "tool_calls", None)
else []
),
},
"finish_reason": response.choices[0].finish_reason,
"finish_reason": getattr(response.choices[0], "finish_reason", None),
}
],
"usage": {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
"prompt_tokens": getattr(getattr(response, "usage", {}), "prompt_tokens", 0),
"completion_tokens": getattr(getattr(response, "usage", {}), "completion_tokens", 0),
"total_tokens": getattr(getattr(response, "usage", {}), "total_tokens", 0),
},
}

Expand Down
5 changes: 3 additions & 2 deletions eval_protocol/mcp_servers/tau2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
def get_server_script_path() -> str:
"""Get the path to the tau2 MCP server script."""
try:
# Try to get from installed package
with importlib.resources.as_file(importlib.resources.files(__package__) / "server.py") as server_path:
# Try to get from installed package. __package__ can be None during some tooling.
package = __package__ if __package__ is not None else __name__
with importlib.resources.as_file(importlib.resources.files(package) / "server.py") as server_path:
return str(server_path)
except (ImportError, FileNotFoundError):
# Fallback for development environment
Expand Down
7 changes: 4 additions & 3 deletions eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,9 +699,9 @@ def _validate_trajectory_termination(env_recordings: Dict, dataset: List[Dict]):
@reward_function
def tau2_airline_eval(
messages: List[Message],
nl_assertions: List[str] = None,
communicate_info: List[str] = None,
actions: List[dict] = None,
nl_assertions: Optional[List[str]] = None,
communicate_info: Optional[List[str]] = None,
actions: Optional[List[dict]] = None,
**kwargs,
) -> EvaluateResult:
"""
Expand All @@ -726,6 +726,7 @@ def tau2_airline_eval(
for msg in messages:
role = msg.role
content = msg.content
assert isinstance(content, str), "Content must be a string"

if role == "system":
trajectory_objects.append(SystemMessage(role=role, content=content))
Expand Down
Loading
Loading