From d94b2e49e806343a1f8c1f0271623c75496c988f Mon Sep 17 00:00:00 2001 From: Benny Chen Date: Sun, 31 Aug 2025 20:22:34 +0800 Subject: [PATCH 1/2] fix pyright round 3 --- eval_protocol/cli_commands/deploy.py | 16 ++++++++---- eval_protocol/evaluation.py | 6 +++-- eval_protocol/get_pep440_version.py | 10 +++++++- eval_protocol/mcp/client/connection.py | 28 ++++++++++++--------- eval_protocol/mcp/clients.py | 12 +++++---- eval_protocol/rewards/code_execution.py | 2 +- eval_protocol/typed_interface.py | 9 +++---- eval_protocol/utils/batch_transformation.py | 2 +- eval_protocol/utils/vite_server.py | 4 +-- 9 files changed, 55 insertions(+), 34 deletions(-) diff --git a/eval_protocol/cli_commands/deploy.py b/eval_protocol/cli_commands/deploy.py index 607751b2..1ae0313b 100644 --- a/eval_protocol/cli_commands/deploy.py +++ b/eval_protocol/cli_commands/deploy.py @@ -29,7 +29,7 @@ import socket import subprocess - def start_process(command, log_path, env=None): + def _fallback_start_process(command, log_path, env=None): """Fallback process starter.""" try: with open(log_path, "w") as log_file: @@ -39,7 +39,7 @@ def start_process(command, log_path, env=None): print(f"Error starting process: {e}") return None - def stop_process(pid): + def _fallback_stop_process(pid): """Fallback process stopper.""" try: import os @@ -48,15 +48,21 @@ def stop_process(pid): except Exception: pass - def start_serveo_and_get_url(local_port, log_path): + def _fallback_start_serveo_and_get_url(local_port, log_path): """Fallback serveo tunnel - returns None to indicate unavailable.""" print("Serveo tunneling not available - development module not found") return None, None - def start_ngrok_and_get_url(local_port, log_path): + def _fallback_start_ngrok_and_get_url(local_port, log_path): """Fallback ngrok tunnel - returns None to indicate unavailable.""" print("ngrok tunneling not available - development module not found") return None, None + + # Expose unified names using fallbacks + start_process = _fallback_start_process + stop_process = _fallback_stop_process + start_serveo_and_get_url = _fallback_start_serveo_and_get_url + start_ngrok_and_get_url = _fallback_start_ngrok_and_get_url else: # Wrap imported helpers to present consistent, simple signatures used below def start_process(command, log_path, env=None): @@ -66,7 +72,7 @@ def stop_process(pid): return _stop_process(pid) def start_serveo_and_get_url(local_port, log_path): - return _start_serveo_and_get_url(local_port=local_port, log_path=log_path) + return _start_serveo_and_get_url(local_port=local_port, log_file_path=log_path) def start_ngrok_and_get_url(local_port, log_path): return _start_ngrok_and_get_url(local_port=local_port, ngrok_log_file=log_path) diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index fe58bb8a..0454cf14 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -7,7 +7,7 @@ import time import types from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast if TYPE_CHECKING: # For type checking only @@ -173,6 +173,8 @@ def __init__( self.description = "" self.display_name = "" self.api_base = os.environ.get("FIREWORKS_API_BASE", "https://api.fireworks.ai") + # Optional requirements string for multi-metric mode (when loaded differently) + self._loaded_multi_metric_requirements_str: Optional[str] = None if self.ts_mode_config: python_code = self.ts_mode_config.get("python_code") @@ -264,7 +266,7 @@ def load_metric_folder(self, metric_name, folder_path): elif isinstance(elt, ast.Str): # Python < 3.8 reqs.append(elt.s) if reqs: - metric_requirements_list = reqs + metric_requirements_list = cast(List[str], reqs) elif isinstance(keyword.value, ast.Constant) and isinstance( keyword.value.value, str ): # Python 3.8+ (single req string) diff --git a/eval_protocol/get_pep440_version.py b/eval_protocol/get_pep440_version.py index 8ebb33c9..c3f939d9 100644 --- a/eval_protocol/get_pep440_version.py +++ b/eval_protocol/get_pep440_version.py @@ -1,7 +1,15 @@ # Cache for PEP 440 version string import subprocess -_version_cache = {"version": None, "base_version": None} +from typing import Dict, Optional, TypedDict + + +class _VersionCache(TypedDict): + version: Optional[str] + base_version: Optional[str] + + +_version_cache: _VersionCache = {"version": None, "base_version": None} def get_pep440_version(base_version=None): diff --git a/eval_protocol/mcp/client/connection.py b/eval_protocol/mcp/client/connection.py index 6916cb08..d8c13f3b 100644 --- a/eval_protocol/mcp/client/connection.py +++ b/eval_protocol/mcp/client/connection.py @@ -306,14 +306,15 @@ async def _get_initial_state_from_mcp_resource(self, session: MCPSession) -> Any resource_content = await mcp_session.read_resource(initial_state_resource.uri) # Handle the new ResourceContents format - if hasattr(resource_content, "text"): + text_value = getattr(resource_content, "text", None) + if text_value is not None: try: - initial_observation = json.loads(resource_content.text) + initial_observation = json.loads(text_value) logger.info( f"Session {session.session_id}: ✅ Successfully parsed JSON initial state with grid_layout: {initial_observation.get('grid_layout', 'N/A')[:20]}..." ) except json.JSONDecodeError: - initial_observation = {"observation": resource_content.text} + initial_observation = {"observation": text_value} elif ( hasattr(resource_content, "contents") and resource_content.contents @@ -321,11 +322,12 @@ async def _get_initial_state_from_mcp_resource(self, session: MCPSession) -> Any ): # Fallback to old format for backward compatibility content = resource_content.contents[0] - if hasattr(content, "text"): + content_text = getattr(content, "text", None) + if content_text is not None: try: - initial_observation = json.loads(content.text) + initial_observation = json.loads(content_text) except json.JSONDecodeError: - initial_observation = {"observation": content.text} + initial_observation = {"observation": content_text} else: initial_observation = {"observation": str(resource_content)} else: @@ -359,11 +361,12 @@ async def _get_initial_state_from_mcp_resource(self, session: MCPSession) -> Any ) # Handle the new ResourceContents format - if hasattr(resource_content, "text"): + text_value_2 = getattr(resource_content, "text", None) + if text_value_2 is not None: try: - initial_observation = json.loads(resource_content.text) + initial_observation = json.loads(text_value_2) except json.JSONDecodeError: - initial_observation = {"observation": resource_content.text} + initial_observation = {"observation": text_value_2} elif ( hasattr(resource_content, "contents") and resource_content.contents @@ -371,11 +374,12 @@ async def _get_initial_state_from_mcp_resource(self, session: MCPSession) -> Any ): # Fallback to old format for backward compatibility content = resource_content.contents[0] - if hasattr(content, "text"): + content_text_2 = getattr(content, "text", None) + if content_text_2 is not None: try: - initial_observation = json.loads(content.text) + initial_observation = json.loads(content_text_2) except json.JSONDecodeError: - initial_observation = {"observation": content.text} + initial_observation = {"observation": content_text_2} else: initial_observation = {"observation": str(content)} else: diff --git a/eval_protocol/mcp/clients.py b/eval_protocol/mcp/clients.py index b1287fe1..c7a95f3c 100644 --- a/eval_protocol/mcp/clients.py +++ b/eval_protocol/mcp/clients.py @@ -29,7 +29,8 @@ def __init__(self, intermediary_server_url: str): async def connect(self): """Establishes connection and MCP session.""" - if self._mcp_session is not None and not self._mcp_session.is_closed: + # ClientSession does not expose a stable public `is_closed`; consider session presence sufficient + if self._mcp_session is not None: logger.debug("Already connected.") return @@ -97,13 +98,14 @@ async def _call_intermediary_tool(self, tool_name: str, tool_args_payload: Dict[ if mcp_response.isError or not mcp_response.content or not hasattr(mcp_response.content[0], "text"): error_message = f"Tool call '{tool_name}' to intermediary failed." if mcp_response.isError and mcp_response.content and hasattr(mcp_response.content[0], "text"): - error_message += f" Details: {mcp_response.content[0].text}" + error_text = getattr(mcp_response.content[0], "text", "") + error_message += f" Details: {error_text}" elif mcp_response.isError: error_message += " No detailed error message in content." logger.error(error_message) try: if mcp_response.content and hasattr(mcp_response.content[0], "text"): - parsed_error = json.loads(mcp_response.content[0].text) + parsed_error = json.loads(getattr(mcp_response.content[0], "text", "")) if isinstance(parsed_error, dict) and "error" in parsed_error: raise RuntimeError(f"{error_message} Nested error: {parsed_error['error']}") except (json.JSONDecodeError, TypeError): @@ -111,12 +113,12 @@ async def _call_intermediary_tool(self, tool_name: str, tool_args_payload: Dict[ raise RuntimeError(error_message) try: - parsed_result = json.loads(mcp_response.content[0].text) + parsed_result = json.loads(getattr(mcp_response.content[0], "text", "")) logger.debug(f"Parsed JSON result from intermediary for '{tool_name}': {parsed_result}") return parsed_result except json.JSONDecodeError as e: logger.error( - f"Failed to parse JSON from intermediary's tool '{tool_name}' response content: {mcp_response.content[0].text}. Error: {e}" + f"Failed to parse JSON from intermediary's tool '{tool_name}' response content: {getattr(mcp_response.content[0], 'text', '')}. Error: {e}" ) raise RuntimeError(f"Failed to parse JSON response from intermediary tool '{tool_name}'.") diff --git a/eval_protocol/rewards/code_execution.py b/eval_protocol/rewards/code_execution.py index 2723d5b7..e918eedd 100644 --- a/eval_protocol/rewards/code_execution.py +++ b/eval_protocol/rewards/code_execution.py @@ -80,7 +80,7 @@ def extract_code_blocks(text: str, language: Optional[str] = None) -> List[Dict[ List of dictionaries with "code" and "language" keys """ pattern = r"```(\w*)\n([\s\S]*?)\n```" - matches = re.findall(pattern, text) + matches = re.findall(pattern, text or "") code_blocks = [] verbose_patterns_removed = [] diff --git a/eval_protocol/typed_interface.py b/eval_protocol/typed_interface.py index 696bd538..78a133c6 100644 --- a/eval_protocol/typed_interface.py +++ b/eval_protocol/typed_interface.py @@ -81,10 +81,9 @@ def decorator(func: F) -> F: has_var_keyword = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()) if not has_var_keyword: - raise ValueError( - f"Function '{func.__name__}' must accept **kwargs parameter. " - f"Please add '**kwargs' to the function signature." - ) + # Return a wrapper that preserves the original signature, but adds **kwargs dynamically + # instead of raising at decoration time. + pass # Setup resources once when the decorator is applied resource_managers = {} @@ -113,7 +112,7 @@ def _is_list_of_message_annotation(annotation: Any) -> bool: inner = non_none[0] inner_origin = get_origin(inner) inner_args = get_args(inner) - return inner_origin in (list, List) and inner_args and inner_args[0] == Message + return (inner_origin in (list, List)) and bool(inner_args) and (inner_args[0] == Message) return False def _prepare_final_args(*args: Any, **kwargs: Any): diff --git a/eval_protocol/utils/batch_transformation.py b/eval_protocol/utils/batch_transformation.py index b250513a..53949af6 100644 --- a/eval_protocol/utils/batch_transformation.py +++ b/eval_protocol/utils/batch_transformation.py @@ -16,7 +16,7 @@ def transform_n_variant_jsonl_to_batch_format( request_id_field: str = "request_id", response_id_field: str = "response_id", messages_field: str = "full_conversation_history", - fallback_messages_fields: List[str] = None, + fallback_messages_fields: Optional[List[str]] = None, ) -> List[Dict[str, Any]]: """ Transform N-variant generation JSONL output into batch evaluation format. diff --git a/eval_protocol/utils/vite_server.py b/eval_protocol/utils/vite_server.py index 02eef31d..8c91cadb 100644 --- a/eval_protocol/utils/vite_server.py +++ b/eval_protocol/utils/vite_server.py @@ -1,7 +1,7 @@ import logging import os from pathlib import Path -from typing import AsyncGenerator, Callable, Optional +from typing import AsyncGenerator, Callable, Optional, Any import uvicorn from fastapi import FastAPI, HTTPException @@ -32,7 +32,7 @@ def __init__( host: str = "localhost", port: int = 8000, index_file: str = "index.html", - lifespan: Optional[Callable[[FastAPI], AsyncGenerator[None, None]]] = None, + lifespan: Optional[Callable[[FastAPI], Any]] = None, ): self.build_dir = Path(build_dir) self.host = host From 065adbdeb49e4aa5f3d8c3ec4184688d07b8ea93 Mon Sep 17 00:00:00 2001 From: Benny Chen Date: Sun, 31 Aug 2025 20:50:55 +0800 Subject: [PATCH 2/2] fix tests --- eval_protocol/rewards/code_execution.py | 12 ++++++++++-- eval_protocol/rewards/deepcoder_reward.py | 7 ++++++- .../rewards/list_comparison_math_reward.py | 7 ++++++- .../rewards/multiple_choice_math_reward.py | 14 ++++++++++++-- eval_protocol/typed_interface.py | 6 +++--- 5 files changed, 37 insertions(+), 9 deletions(-) diff --git a/eval_protocol/rewards/code_execution.py b/eval_protocol/rewards/code_execution.py index e918eedd..38c7189a 100644 --- a/eval_protocol/rewards/code_execution.py +++ b/eval_protocol/rewards/code_execution.py @@ -1098,7 +1098,15 @@ def fractional_code_reward( }, ) - code_blocks = extract_code_blocks(response_content, language) + # Normalize content to string; Message.content may be str or list of content parts + _last_content = response_content + response_content_str = ( + _last_content + if isinstance(_last_content, str) + else "".join([getattr(p, "text", "") for p in (_last_content or [])]) + ) + + code_blocks = extract_code_blocks(response_content_str, language) if not code_blocks: return EvaluateResult( @@ -1617,7 +1625,7 @@ class Capturing(list): def __enter__(self): self._stdout = sys.stdout sys.stdout = self._stringio = StringIO() - self._stringio.close = lambda x: None + self._stringio.close = lambda: None return self def __exit__(self, *args): diff --git a/eval_protocol/rewards/deepcoder_reward.py b/eval_protocol/rewards/deepcoder_reward.py index ebdc44bb..37348cb0 100644 --- a/eval_protocol/rewards/deepcoder_reward.py +++ b/eval_protocol/rewards/deepcoder_reward.py @@ -73,7 +73,12 @@ def deepcoder_code_reward( is_score_valid=False, ) - assistant_content = messages[-1].content + assistant_content_raw = messages[-1].content + assistant_content = ( + assistant_content_raw + if isinstance(assistant_content_raw, str) + else "".join([getattr(p, "text", "") for p in (assistant_content_raw or [])]) + ) test_cases = ground_truth code_blocks = extract_code_blocks(assistant_content, language) diff --git a/eval_protocol/rewards/list_comparison_math_reward.py b/eval_protocol/rewards/list_comparison_math_reward.py index ceaef012..5ee34fe2 100644 --- a/eval_protocol/rewards/list_comparison_math_reward.py +++ b/eval_protocol/rewards/list_comparison_math_reward.py @@ -127,7 +127,12 @@ def list_comparison_math_reward( }, ) - gen_content = messages[-1].content + gen_content_raw = messages[-1].content + gen_content = ( + gen_content_raw + if isinstance(gen_content_raw, str) + else "".join([getattr(p, "text", "") for p in (gen_content_raw or [])]) + ) orig_content = ground_truth if not gen_content: diff --git a/eval_protocol/rewards/multiple_choice_math_reward.py b/eval_protocol/rewards/multiple_choice_math_reward.py index 5768de80..a1e9c1df 100644 --- a/eval_protocol/rewards/multiple_choice_math_reward.py +++ b/eval_protocol/rewards/multiple_choice_math_reward.py @@ -134,7 +134,12 @@ def multiple_choice_math_reward( if messages and len(messages) > 0: gen_response_message = messages[-1] if gen_response_message.role == "assistant": - gen_content = gen_response_message.content or "" + raw_gen_content = gen_response_message.content + gen_content = ( + raw_gen_content + if isinstance(raw_gen_content, str) + else "".join([getattr(p, "text", "") for p in (raw_gen_content or [])]) + ) if not gen_content: metrics["error_generated_message"] = MetricResult( @@ -152,7 +157,12 @@ def multiple_choice_math_reward( if ground_truth and len(ground_truth) > 0: orig_response_message = ground_truth[0] if orig_response_message.role == "assistant": - orig_content = orig_response_message.content or "" + raw_orig_content = orig_response_message.content + orig_content = ( + raw_orig_content + if isinstance(raw_orig_content, str) + else "".join([getattr(p, "text", "") for p in (raw_orig_content or [])]) + ) if not orig_content: metrics["error_original_message"] = MetricResult( diff --git a/eval_protocol/typed_interface.py b/eval_protocol/typed_interface.py index 78a133c6..97dafdf3 100644 --- a/eval_protocol/typed_interface.py +++ b/eval_protocol/typed_interface.py @@ -81,9 +81,9 @@ def decorator(func: F) -> F: has_var_keyword = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()) if not has_var_keyword: - # Return a wrapper that preserves the original signature, but adds **kwargs dynamically - # instead of raising at decoration time. - pass + raise ValueError( + f"Function '{func.__name__}' must accept **kwargs parameter. Please add '**kwargs' to the function signature." + ) # Setup resources once when the decorator is applied resource_managers = {}