diff --git a/eval_protocol/_version.py b/eval_protocol/_version.py index 4a548c9a..988fcdfc 100644 --- a/eval_protocol/_version.py +++ b/eval_protocol/_version.py @@ -121,7 +121,9 @@ def run_command( if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = process.communicate()[0].strip().decode() + stdout_bytes = process.communicate()[0] + stdout_raw = stdout_bytes.decode() if isinstance(stdout_bytes, (bytes, bytearray)) else stdout_bytes + stdout = str(stdout_raw).strip() if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) diff --git a/eval_protocol/adapters/bigquery.py b/eval_protocol/adapters/bigquery.py index 7b79884b..db4cbda0 100644 --- a/eval_protocol/adapters/bigquery.py +++ b/eval_protocol/adapters/bigquery.py @@ -7,34 +7,36 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Union, cast, TypeAlias from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message logger = logging.getLogger(__name__) try: + # Import at runtime if available from google.auth.exceptions import DefaultCredentialsError - from google.cloud import bigquery + from google.cloud import bigquery as _bigquery_runtime # type: ignore from google.cloud.exceptions import Forbidden, NotFound from google.oauth2 import service_account BIGQUERY_AVAILABLE = True except ImportError: + # Provide fallbacks for type checking/runtime when package is missing + DefaultCredentialsError = Exception # type: ignore[assignment] + Forbidden = Exception # type: ignore[assignment] + NotFound = Exception # type: ignore[assignment] + service_account: Any + service_account = None + _bigquery_runtime = None # type: ignore[assignment] BIGQUERY_AVAILABLE = False # Optional dependency: avoid noisy warnings during import logger.debug("Google Cloud BigQuery not installed. Optional feature disabled.") -# Avoid importing BigQuery types at runtime for annotations when not installed -if TYPE_CHECKING: - from google.cloud import bigquery as _bigquery_type - - QueryParameterType = Union[ - _bigquery_type.ScalarQueryParameter, - _bigquery_type.ArrayQueryParameter, - ] -else: - QueryParameterType = Any +# Simple type aliases to avoid importing optional google types under pyright +QueryParameterType: TypeAlias = Any +BigQueryClient: TypeAlias = Any +QueryJobConfig: TypeAlias = Any # Type alias for transformation function TransformFunction = Callable[[Dict[str, Any]], Dict[str, Any]] @@ -98,7 +100,13 @@ def __init__( client_args["location"] = location client_args.update(client_kwargs) - self.client = bigquery.Client(**client_args) + # Use runtime alias to avoid basedpyright import symbol error when lib is missing + if _bigquery_runtime is None: + raise ImportError( + "google-cloud-bigquery is not installed. Install with: pip install 'eval-protocol[bigquery]'" + ) + # Avoid strict typing on optional dependency + self.client = _bigquery_runtime.Client(**client_args) # type: ignore[no-untyped-call, assignment] except DefaultCredentialsError as e: logger.error("Failed to authenticate with BigQuery: %s", e) @@ -139,7 +147,9 @@ def get_evaluation_rows( """ try: # Configure query job - job_config = bigquery.QueryJobConfig() + if _bigquery_runtime is None: + raise RuntimeError("BigQuery runtime not available") + job_config = _bigquery_runtime.QueryJobConfig() # type: ignore[no-untyped-call] if query_params: job_config.query_parameters = query_params if self.location: diff --git a/eval_protocol/adapters/langchain.py b/eval_protocol/adapters/langchain.py index 15c6cc90..df6818a5 100644 --- a/eval_protocol/adapters/langchain.py +++ b/eval_protocol/adapters/langchain.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from typing import List +from typing import Any, Dict, List, Optional from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage @@ -49,10 +49,10 @@ def serialize_lc_message_to_ep(msg: BaseMessage) -> Message: parts.append(item) content = "\n".join(parts) - tool_calls_payload = None + tool_calls_payload: Optional[List[Dict[str, Any]]] = None - def _normalize_tool_calls(tc_list: list) -> list[dict]: - mapped: List[dict] = [] + def _normalize_tool_calls(tc_list: List[Any]) -> List[Dict[str, Any]]: + mapped: List[Dict[str, Any]] = [] for call in tc_list: if not isinstance(call, dict): continue @@ -104,8 +104,13 @@ def _normalize_tool_calls(tc_list: list) -> list[dict]: if collected: reasoning_content = "\n\n".join([s for s in collected if s]) or None + # Message.tool_calls expects List[ChatCompletionMessageToolCall] | None. + # We pass through Dicts at runtime but avoid type error by casting. ep_msg = Message( - role="assistant", content=content, tool_calls=tool_calls_payload, reasoning_content=reasoning_content + role="assistant", + content=content, + tool_calls=tool_calls_payload, # type: ignore[arg-type] + reasoning_content=reasoning_content, ) _dbg_print( "[EP-Ser] -> EP Message:", diff --git a/eval_protocol/adapters/langfuse.py b/eval_protocol/adapters/langfuse.py index 0061b983..cc51e1b2 100644 --- a/eval_protocol/adapters/langfuse.py +++ b/eval_protocol/adapters/langfuse.py @@ -6,7 +6,7 @@ import logging from datetime import datetime -from typing import Any, Dict, Iterator, List, Optional +from typing import Any, Dict, Iterator, List, Optional, cast from eval_protocol.models import EvaluationRow, InputMetadata, Message @@ -63,7 +63,7 @@ def __init__( if not LANGFUSE_AVAILABLE: raise ImportError("Langfuse not installed. Install with: pip install 'eval-protocol[langfuse]'") - self.client = Langfuse(public_key=public_key, secret_key=secret_key, host=host) + self.client = cast(Any, Langfuse)(public_key=public_key, secret_key=secret_key, host=host) self.project_id = project_id def get_evaluation_rows( diff --git a/eval_protocol/benchmarks/test_gpqa.py b/eval_protocol/benchmarks/test_gpqa.py index 2cbda574..e2c449c3 100644 --- a/eval_protocol/benchmarks/test_gpqa.py +++ b/eval_protocol/benchmarks/test_gpqa.py @@ -58,6 +58,8 @@ def _extract_abcd_letter(text: str) -> str | None: def _strip_gt_messages(msgs: list[Message]) -> list[Message]: + # assert that all the messages just have a plain .content string field + assert all(isinstance(m.content, str) for m in msgs), "Messages must have a plain .content string field" return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))] diff --git a/eval_protocol/benchmarks/test_tau_bench_retail.py b/eval_protocol/benchmarks/test_tau_bench_retail.py index f7240df8..d26d2675 100644 --- a/eval_protocol/benchmarks/test_tau_bench_retail.py +++ b/eval_protocol/benchmarks/test_tau_bench_retail.py @@ -188,6 +188,7 @@ def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow: task = Task( id="Filler", evaluation_criteria=evaluation_criteria, user_scenario=UserScenario(instructions="Filler") ) # id and user_scenario are required for the Task type but not used in calculating reward + assert task.evaluation_criteria is not None, "Task evaluation criteria is None" if RewardType.DB in task.evaluation_criteria.reward_basis: env_reward_info = EnvironmentEvaluator.calculate_reward( diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index 0454cf14..caedbf67 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -257,22 +257,21 @@ def load_metric_folder(self, metric_name, folder_path): for keyword in decorator_node.keywords: if keyword.arg == "requirements": if isinstance(keyword.value, ast.List): - reqs = [] + reqs: List[str] = [] for elt in keyword.value.elts: - if isinstance(elt, ast.Constant) and isinstance( - elt.value, str - ): # Python 3.8+ - reqs.append(elt.value) + if isinstance(elt, ast.Constant): # Python 3.8+ + if isinstance(elt.value, str): + reqs.append(cast(str, elt.value)) elif isinstance(elt, ast.Str): # Python < 3.8 - reqs.append(elt.s) + reqs.append(cast(str, elt.s)) if reqs: metric_requirements_list = cast(List[str], reqs) elif isinstance(keyword.value, ast.Constant) and isinstance( keyword.value.value, str ): # Python 3.8+ (single req string) - metric_requirements_list = [keyword.value.value] + metric_requirements_list = [cast(str, keyword.value.value)] elif isinstance(keyword.value, ast.Str): # Python < 3.8 (single req string) - metric_requirements_list = [keyword.value.s] + metric_requirements_list = [cast(str, keyword.value.s)] break if metric_requirements_list: break diff --git a/eval_protocol/mcp/client/connection.py b/eval_protocol/mcp/client/connection.py index a6fcd53d..f0c85ac6 100644 --- a/eval_protocol/mcp/client/connection.py +++ b/eval_protocol/mcp/client/connection.py @@ -441,9 +441,19 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict) # Extract data plane results (observation only) if tool_result.content and len(tool_result.content) > 0: content = tool_result.content[0] - if hasattr(content, "text"): + # Safely attempt to read a "text" attribute if present across content types + text_attr = getattr(content, "text", None) + if isinstance(text_attr, str): + content_text = text_attr + elif isinstance(text_attr, list): + # text can also be an array of parts with optional .text fields + content_text = "".join([getattr(p, "text", "") for p in text_attr]) + else: + content_text = None + + if isinstance(content_text, str): # Fix: Handle empty or invalid JSON responses gracefully - if not content.text or content.text.strip() == "": + if content_text.strip() == "": logger.warning(f"Session {session.session_id}: Empty tool response from {tool_name}") observation = { "observation": "empty_response", @@ -451,14 +461,14 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict) } else: try: - observation = json.loads(content.text) + observation = json.loads(content_text) except json.JSONDecodeError as e: logger.warning( - f"Session {session.session_id}: Invalid JSON from {tool_name}: {content.text}. Error: {e}" + f"Session {session.session_id}: Invalid JSON from {tool_name}: {content_text}. Error: {e}" ) # Create a structured response from the raw text observation = { - "observation": content.text, + "observation": content_text, "session_id": session.session_id, "error": "invalid_json_response", } diff --git a/eval_protocol/mcp/execution/policy.py b/eval_protocol/mcp/execution/policy.py index 38fd51e4..1adb9b95 100644 --- a/eval_protocol/mcp/execution/policy.py +++ b/eval_protocol/mcp/execution/policy.py @@ -117,10 +117,19 @@ def _setup_litellm_caching( logger.info("🗄️ Initialized disk caching") elif cache_type == "s3": - from litellm.caching.s3_cache import S3Cache - - litellm.cache = S3Cache() - logger.info("🗄️ Initialized S3 caching") + try: + from litellm.caching.s3_cache import S3Cache + + # Some versions require positional or named 's3_bucket_name' + s3_bucket_name = os.getenv("LITELLM_S3_BUCKET") + if not s3_bucket_name: + raise ValueError("Missing LITELLM_S3_BUCKET for S3 cache") + # Use explicit arg name expected by basedpyright + litellm.cache = S3Cache(s3_bucket_name=s3_bucket_name) + logger.info("🗄️ Initialized S3 caching for bucket %s", s3_bucket_name) + except Exception as e: + logger.warning(f"Failed to initialize S3 cache ({e}); falling back to in-memory cache") + litellm.cache = Cache() except Exception as e: logger.warning(f"Failed to setup {cache_type} caching: {e}. Falling back to in-memory cache.") @@ -147,7 +156,7 @@ def _clean_messages_for_api(self, messages: List[Dict]) -> List[Dict]: clean_messages.append(clean_msg) return clean_messages - async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict: + async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> Dict[str, Any]: """ Make an LLM API call with retry logic and caching. @@ -162,7 +171,7 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict: clean_messages = self._clean_messages_for_api(messages) # Prepare request parameters - request_params = { + request_params: Dict[str, Any] = { "messages": clean_messages, "temperature": self.temperature, "max_tokens": self.max_tokens, @@ -188,7 +197,8 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict: response = await acompletion(model=self.model_id, **request_params) # Log cache hit/miss for monitoring - cache_hit = getattr(response, "_hidden_params", {}).get("cache_hit", False) + hidden = getattr(response, "_hidden_params", {}) + cache_hit = hidden.get("cache_hit", False) if isinstance(hidden, dict) else False if cache_hit: logger.debug(f"🎯 Cache hit for model: {self.model_id}") else: @@ -199,31 +209,34 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict: "choices": [ { "message": { - "role": response.choices[0].message.role, - "content": response.choices[0].message.content, + "role": getattr(getattr(response.choices[0], "message", object()), "role", "assistant"), + "content": getattr(getattr(response.choices[0], "message", object()), "content", None), "tool_calls": ( [ { - "id": tc.id, - "type": tc.type, + "id": getattr(tc, "id", None), + "type": getattr(tc, "type", "function"), "function": { - "name": tc.function.name, - "arguments": tc.function.arguments, + "name": getattr(getattr(tc, "function", None), "name", "tool"), + "arguments": getattr(getattr(tc, "function", None), "arguments", "{}"), }, } - for tc in (response.choices[0].message.tool_calls or []) + for tc in ( + getattr(getattr(response.choices[0], "message", object()), "tool_calls", []) + or [] + ) ] - if response.choices[0].message.tool_calls + if getattr(getattr(response.choices[0], "message", object()), "tool_calls", None) else [] ), }, - "finish_reason": response.choices[0].finish_reason, + "finish_reason": getattr(response.choices[0], "finish_reason", None), } ], "usage": { - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, + "prompt_tokens": getattr(getattr(response, "usage", {}), "prompt_tokens", 0), + "completion_tokens": getattr(getattr(response, "usage", {}), "completion_tokens", 0), + "total_tokens": getattr(getattr(response, "usage", {}), "total_tokens", 0), }, } diff --git a/eval_protocol/mcp_servers/tau2/__init__.py b/eval_protocol/mcp_servers/tau2/__init__.py index 8076b435..81a9da36 100644 --- a/eval_protocol/mcp_servers/tau2/__init__.py +++ b/eval_protocol/mcp_servers/tau2/__init__.py @@ -12,8 +12,9 @@ def get_server_script_path() -> str: """Get the path to the tau2 MCP server script.""" try: - # Try to get from installed package - with importlib.resources.as_file(importlib.resources.files(__package__) / "server.py") as server_path: + # Try to get from installed package. __package__ can be None during some tooling. + package = __package__ if __package__ is not None else __name__ + with importlib.resources.as_file(importlib.resources.files(package) / "server.py") as server_path: return str(server_path) except (ImportError, FileNotFoundError): # Fallback for development environment diff --git a/eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py b/eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py index ec7c3944..ed7e77e7 100644 --- a/eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py +++ b/eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py @@ -699,9 +699,9 @@ def _validate_trajectory_termination(env_recordings: Dict, dataset: List[Dict]): @reward_function def tau2_airline_eval( messages: List[Message], - nl_assertions: List[str] = None, - communicate_info: List[str] = None, - actions: List[dict] = None, + nl_assertions: Optional[List[str]] = None, + communicate_info: Optional[List[str]] = None, + actions: Optional[List[dict]] = None, **kwargs, ) -> EvaluateResult: """ @@ -726,6 +726,7 @@ def tau2_airline_eval( for msg in messages: role = msg.role content = msg.content + assert isinstance(content, str), "Content must be a string" if role == "system": trajectory_objects.append(SystemMessage(role=role, content=content)) diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py index aa602c60..e036af3d 100644 --- a/eval_protocol/pytest/default_agent_rollout_processor.py +++ b/eval_protocol/pytest/default_agent_rollout_processor.py @@ -2,7 +2,7 @@ import json import logging import os -from typing import Any, AsyncIterator, List, Optional, Union +from typing import Any, AsyncIterator, List, Optional, Union, Dict from mcp.types import CallToolResult, TextContent from openai import NOT_GIVEN, NotGiven @@ -35,9 +35,31 @@ async def setup(self): if self.mcp_client: await self.mcp_client.connect_to_servers() - async def _get_tools(self) -> Optional[List[ChatCompletionToolParam]]: + async def _get_tools(self) -> Optional[List[dict[str, Any]]]: if self.evaluation_row.tools is None: - self.evaluation_row.tools = await self.mcp_client.get_available_tools() if self.mcp_client else None + if self.mcp_client: + raw_tools = await self.mcp_client.get_available_tools() + tools_dicts: List[dict[str, Any]] = [] + for t in raw_tools or []: + if isinstance(t, dict): + # Already a dict-like structure + tools_dicts.append(t) + continue + # Fallback: extract attributes from OpenAI types + tool_type = getattr(t, "type", "function") + func = getattr(t, "function", None) + name = getattr(func, "name", None) + params = getattr(func, "parameters", None) + if hasattr(params, "model_dump"): + params_payload = params.model_dump() + elif isinstance(params, dict): + params_payload = params + else: + params_payload = {} + tools_dicts.append({"type": tool_type, "function": {"name": name, "parameters": params_payload}}) + self.evaluation_row.tools = tools_dicts + else: + self.evaluation_row.tools = None return self.evaluation_row.tools @property @@ -48,7 +70,7 @@ def append_message_and_log(self, message: Message): self.messages.append(message) self.logger.log(self.evaluation_row) - async def call_agent(self) -> str: + async def call_agent(self) -> Optional[Union[str, List[ChatCompletionContentPartTextParam]]]: """ Call the assistant with the user query. """ @@ -66,7 +88,7 @@ async def call_agent(self) -> str: tool_args_dict = json.loads(tool_args) # Create a task for each tool call - task = self._execute_tool_call(tool_call_id, tool_name, tool_args_dict) + task = asyncio.create_task(self._execute_tool_call(tool_call_id, tool_name, tool_args_dict)) tool_tasks.append(task) # Execute all tool calls in parallel @@ -81,14 +103,58 @@ async def call_agent(self) -> str: return await self.call_agent() return message.content - async def _call_model(self, messages: list[Message], tools: Optional[list[ChatCompletionToolParam]]) -> Message: - messages = [message.model_dump() if hasattr(message, "model_dump") else message for message in messages] - tools = [{"function": tool["function"].model_dump(), "type": "function"} for tool in tools] if tools else [] - response = await self._policy._make_llm_call(messages=messages, tools=tools) + async def _call_model(self, messages: list[Message], tools: Optional[List[dict[str, Any]]]) -> Message: + # Convert Message models to plain dicts for LLM call + messages_payload: List[Dict[str, Any]] = [ + message.model_dump() if hasattr(message, "model_dump") else message # type: ignore[misc] + for message in messages + ] + # Normalize tool definitions into OpenAI-compatible dicts + payload_tools: List[Dict[str, Any]] = [] + for tool in tools or []: + if isinstance(tool, dict): + fn = tool.get("function") + if hasattr(fn, "model_dump"): + fn_payload = fn.model_dump() + elif isinstance(fn, dict): + fn_payload = fn + else: + # Best effort fallback + name = getattr(fn, "name", None) + params = getattr(fn, "parameters", None) + if hasattr(params, "model_dump"): + params_payload = params.model_dump() + elif isinstance(params, dict): + params_payload = params + else: + params_payload = {} + fn_payload = {"name": name, "parameters": params_payload} + payload_tools.append({"type": tool.get("type", "function"), "function": fn_payload}) + else: + # Attribute-based fallback + tool_type = getattr(tool, "type", "function") + func = getattr(tool, "function", None) + name = getattr(func, "name", None) + params = getattr(func, "parameters", None) + if hasattr(params, "model_dump"): + params_payload = params.model_dump() + elif isinstance(params, dict): + params_payload = params + else: + params_payload = {} + payload_tools.append({"type": tool_type, "function": {"name": name, "parameters": params_payload}}) + + response = await self._policy._make_llm_call(messages=messages_payload, tools=payload_tools) + # Coerce content to a string to align with our Message model type expectations + raw_content = response["choices"][0]["message"].get("content") + if isinstance(raw_content, list): + content_for_model = "".join([getattr(p, "text", str(p)) for p in raw_content]) + else: + content_for_model = raw_content return Message( role=response["choices"][0]["message"]["role"], - content=response["choices"][0]["message"]["content"], - tool_calls=response["choices"][0]["message"]["tool_calls"], + content=content_for_model, + tool_calls=response["choices"][0]["message"].get("tool_calls"), ) async def _execute_tool_call( @@ -98,16 +164,22 @@ async def _execute_tool_call( Execute a single tool call and return the tool_call_id and content. This method is designed to be used with asyncio.gather() for parallel execution. """ + assert self.mcp_client is not None, "MCP client is not initialized" tool_result = await self.mcp_client.call_tool(tool_name, tool_args_dict) content = self._get_content_from_tool_result(tool_result) return tool_call_id, content def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[TextContent]: - if tool_result.structuredContent: + if getattr(tool_result, "structuredContent", None): return [TextContent(text=json.dumps(tool_result.structuredContent), type="text")] - if not all(isinstance(content, TextContent) for content in tool_result.content): - raise NotImplementedError("Non-text content is not supported yet") - return tool_result.content + normalized: List[TextContent] = [] + for content in getattr(tool_result, "content", []) or []: + if isinstance(content, TextContent): + normalized.append(content) + else: + text_val = getattr(content, "text", str(content)) + normalized.append(TextContent(text=str(text_val), type="text")) + return normalized def _format_tool_message_content( self, content: List[TextContent] diff --git a/eval_protocol/rewards/json_schema.py b/eval_protocol/rewards/json_schema.py index 06f2fc5e..c3b7e5ae 100644 --- a/eval_protocol/rewards/json_schema.py +++ b/eval_protocol/rewards/json_schema.py @@ -304,7 +304,15 @@ def json_schema_reward_with_llm_judge( if "error" in schema_result.metrics: return schema_result last_message = messages[-1] - content = last_message.get("content", "") + assert last_message is not None, "Last message is None" + # Support both dict-shaped messages and pydantic Message objects + if isinstance(last_message, dict): + content = last_message.get("content", "") + else: + try: + content = getattr(last_message, "content", "") + except Exception: + content = "" json_str_from_msg = "" try: pattern = r"```(?:json)?\s*([\s\S]*?)```"