Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions eval_protocol/cli_commands/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import socket
import subprocess

def start_process(command, log_path, env=None):
def _fallback_start_process(command, log_path, env=None):
"""Fallback process starter."""
try:
with open(log_path, "w") as log_file:
Expand All @@ -39,7 +39,7 @@ def start_process(command, log_path, env=None):
print(f"Error starting process: {e}")
return None

def stop_process(pid):
def _fallback_stop_process(pid):
"""Fallback process stopper."""
try:
import os
Expand All @@ -48,15 +48,21 @@ def stop_process(pid):
except Exception:
pass

def start_serveo_and_get_url(local_port, log_path):
def _fallback_start_serveo_and_get_url(local_port, log_path):
"""Fallback serveo tunnel - returns None to indicate unavailable."""
print("Serveo tunneling not available - development module not found")
return None, None

def start_ngrok_and_get_url(local_port, log_path):
def _fallback_start_ngrok_and_get_url(local_port, log_path):
"""Fallback ngrok tunnel - returns None to indicate unavailable."""
print("ngrok tunneling not available - development module not found")
return None, None

# Expose unified names using fallbacks
start_process = _fallback_start_process
stop_process = _fallback_stop_process
start_serveo_and_get_url = _fallback_start_serveo_and_get_url
start_ngrok_and_get_url = _fallback_start_ngrok_and_get_url
else:
# Wrap imported helpers to present consistent, simple signatures used below
def start_process(command, log_path, env=None):
Expand All @@ -66,7 +72,7 @@ def stop_process(pid):
return _stop_process(pid)

def start_serveo_and_get_url(local_port, log_path):
return _start_serveo_and_get_url(local_port=local_port, log_path=log_path)
return _start_serveo_and_get_url(local_port=local_port, log_file_path=log_path)

def start_ngrok_and_get_url(local_port, log_path):
return _start_ngrok_and_get_url(local_port=local_port, ngrok_log_file=log_path)
Expand Down
6 changes: 4 additions & 2 deletions eval_protocol/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
import types
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast

if TYPE_CHECKING:
# For type checking only
Expand Down Expand Up @@ -173,6 +173,8 @@ def __init__(
self.description = ""
self.display_name = ""
self.api_base = os.environ.get("FIREWORKS_API_BASE", "https://api.fireworks.ai")
# Optional requirements string for multi-metric mode (when loaded differently)
self._loaded_multi_metric_requirements_str: Optional[str] = None

if self.ts_mode_config:
python_code = self.ts_mode_config.get("python_code")
Expand Down Expand Up @@ -264,7 +266,7 @@ def load_metric_folder(self, metric_name, folder_path):
elif isinstance(elt, ast.Str): # Python < 3.8
reqs.append(elt.s)
if reqs:
metric_requirements_list = reqs
metric_requirements_list = cast(List[str], reqs)
elif isinstance(keyword.value, ast.Constant) and isinstance(
keyword.value.value, str
): # Python 3.8+ (single req string)
Expand Down
10 changes: 9 additions & 1 deletion eval_protocol/get_pep440_version.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
# Cache for PEP 440 version string
import subprocess

_version_cache = {"version": None, "base_version": None}
from typing import Dict, Optional, TypedDict


class _VersionCache(TypedDict):
version: Optional[str]
base_version: Optional[str]


_version_cache: _VersionCache = {"version": None, "base_version": None}


def get_pep440_version(base_version=None):
Expand Down
28 changes: 16 additions & 12 deletions eval_protocol/mcp/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,26 +306,28 @@ async def _get_initial_state_from_mcp_resource(self, session: MCPSession) -> Any
resource_content = await mcp_session.read_resource(initial_state_resource.uri)

# Handle the new ResourceContents format
if hasattr(resource_content, "text"):
text_value = getattr(resource_content, "text", None)
if text_value is not None:
try:
initial_observation = json.loads(resource_content.text)
initial_observation = json.loads(text_value)
logger.info(
f"Session {session.session_id}: ✅ Successfully parsed JSON initial state with grid_layout: {initial_observation.get('grid_layout', 'N/A')[:20]}..."
)
except json.JSONDecodeError:
initial_observation = {"observation": resource_content.text}
initial_observation = {"observation": text_value}
elif (
hasattr(resource_content, "contents")
and resource_content.contents
and len(resource_content.contents) > 0
):
# Fallback to old format for backward compatibility
content = resource_content.contents[0]
if hasattr(content, "text"):
content_text = getattr(content, "text", None)
if content_text is not None:
try:
initial_observation = json.loads(content.text)
initial_observation = json.loads(content_text)
except json.JSONDecodeError:
initial_observation = {"observation": content.text}
initial_observation = {"observation": content_text}
else:
initial_observation = {"observation": str(resource_content)}
else:
Expand Down Expand Up @@ -359,23 +361,25 @@ async def _get_initial_state_from_mcp_resource(self, session: MCPSession) -> Any
)

# Handle the new ResourceContents format
if hasattr(resource_content, "text"):
text_value_2 = getattr(resource_content, "text", None)
if text_value_2 is not None:
try:
initial_observation = json.loads(resource_content.text)
initial_observation = json.loads(text_value_2)
except json.JSONDecodeError:
initial_observation = {"observation": resource_content.text}
initial_observation = {"observation": text_value_2}
elif (
hasattr(resource_content, "contents")
and resource_content.contents
and len(resource_content.contents) > 0
):
# Fallback to old format for backward compatibility
content = resource_content.contents[0]
if hasattr(content, "text"):
content_text_2 = getattr(content, "text", None)
if content_text_2 is not None:
try:
initial_observation = json.loads(content.text)
initial_observation = json.loads(content_text_2)
except json.JSONDecodeError:
initial_observation = {"observation": content.text}
initial_observation = {"observation": content_text_2}
else:
initial_observation = {"observation": str(content)}
else:
Expand Down
12 changes: 7 additions & 5 deletions eval_protocol/mcp/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def __init__(self, intermediary_server_url: str):

async def connect(self):
"""Establishes connection and MCP session."""
if self._mcp_session is not None and not self._mcp_session.is_closed:
# ClientSession does not expose a stable public `is_closed`; consider session presence sufficient
if self._mcp_session is not None:
logger.debug("Already connected.")
return

Expand Down Expand Up @@ -97,26 +98,27 @@ async def _call_intermediary_tool(self, tool_name: str, tool_args_payload: Dict[
if mcp_response.isError or not mcp_response.content or not hasattr(mcp_response.content[0], "text"):
error_message = f"Tool call '{tool_name}' to intermediary failed."
if mcp_response.isError and mcp_response.content and hasattr(mcp_response.content[0], "text"):
error_message += f" Details: {mcp_response.content[0].text}"
error_text = getattr(mcp_response.content[0], "text", "")
error_message += f" Details: {error_text}"
elif mcp_response.isError:
error_message += " No detailed error message in content."
logger.error(error_message)
try:
if mcp_response.content and hasattr(mcp_response.content[0], "text"):
parsed_error = json.loads(mcp_response.content[0].text)
parsed_error = json.loads(getattr(mcp_response.content[0], "text", ""))
if isinstance(parsed_error, dict) and "error" in parsed_error:
raise RuntimeError(f"{error_message} Nested error: {parsed_error['error']}")
except (json.JSONDecodeError, TypeError):
pass
raise RuntimeError(error_message)

try:
parsed_result = json.loads(mcp_response.content[0].text)
parsed_result = json.loads(getattr(mcp_response.content[0], "text", ""))
logger.debug(f"Parsed JSON result from intermediary for '{tool_name}': {parsed_result}")
return parsed_result
except json.JSONDecodeError as e:
logger.error(
f"Failed to parse JSON from intermediary's tool '{tool_name}' response content: {mcp_response.content[0].text}. Error: {e}"
f"Failed to parse JSON from intermediary's tool '{tool_name}' response content: {getattr(mcp_response.content[0], 'text', '')}. Error: {e}"
)
raise RuntimeError(f"Failed to parse JSON response from intermediary tool '{tool_name}'.")

Expand Down
14 changes: 11 additions & 3 deletions eval_protocol/rewards/code_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def extract_code_blocks(text: str, language: Optional[str] = None) -> List[Dict[
List of dictionaries with "code" and "language" keys
"""
pattern = r"```(\w*)\n([\s\S]*?)\n```"
matches = re.findall(pattern, text)
matches = re.findall(pattern, text or "")

code_blocks = []
verbose_patterns_removed = []
Expand Down Expand Up @@ -1098,7 +1098,15 @@ def fractional_code_reward(
},
)

code_blocks = extract_code_blocks(response_content, language)
# Normalize content to string; Message.content may be str or list of content parts
_last_content = response_content
response_content_str = (
_last_content
if isinstance(_last_content, str)
else "".join([getattr(p, "text", "") for p in (_last_content or [])])
)

code_blocks = extract_code_blocks(response_content_str, language)

if not code_blocks:
return EvaluateResult(
Expand Down Expand Up @@ -1617,7 +1625,7 @@ class Capturing(list):
def __enter__(self):
self._stdout = sys.stdout
sys.stdout = self._stringio = StringIO()
self._stringio.close = lambda x: None
self._stringio.close = lambda: None
return self

def __exit__(self, *args):
Expand Down
7 changes: 6 additions & 1 deletion eval_protocol/rewards/deepcoder_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@ def deepcoder_code_reward(
is_score_valid=False,
)

assistant_content = messages[-1].content
assistant_content_raw = messages[-1].content
assistant_content = (
assistant_content_raw
if isinstance(assistant_content_raw, str)
else "".join([getattr(p, "text", "") for p in (assistant_content_raw or [])])
)
test_cases = ground_truth

code_blocks = extract_code_blocks(assistant_content, language)
Expand Down
7 changes: 6 additions & 1 deletion eval_protocol/rewards/list_comparison_math_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,12 @@ def list_comparison_math_reward(
},
)

gen_content = messages[-1].content
gen_content_raw = messages[-1].content
gen_content = (
gen_content_raw
if isinstance(gen_content_raw, str)
else "".join([getattr(p, "text", "") for p in (gen_content_raw or [])])
)
orig_content = ground_truth

if not gen_content:
Expand Down
14 changes: 12 additions & 2 deletions eval_protocol/rewards/multiple_choice_math_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,12 @@ def multiple_choice_math_reward(
if messages and len(messages) > 0:
gen_response_message = messages[-1]
if gen_response_message.role == "assistant":
gen_content = gen_response_message.content or ""
raw_gen_content = gen_response_message.content
gen_content = (
raw_gen_content
if isinstance(raw_gen_content, str)
else "".join([getattr(p, "text", "") for p in (raw_gen_content or [])])
)

if not gen_content:
metrics["error_generated_message"] = MetricResult(
Expand All @@ -152,7 +157,12 @@ def multiple_choice_math_reward(
if ground_truth and len(ground_truth) > 0:
orig_response_message = ground_truth[0]
if orig_response_message.role == "assistant":
orig_content = orig_response_message.content or ""
raw_orig_content = orig_response_message.content
orig_content = (
raw_orig_content
if isinstance(raw_orig_content, str)
else "".join([getattr(p, "text", "") for p in (raw_orig_content or [])])
)

if not orig_content:
metrics["error_original_message"] = MetricResult(
Expand Down
5 changes: 2 additions & 3 deletions eval_protocol/typed_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ def decorator(func: F) -> F:

if not has_var_keyword:
raise ValueError(
f"Function '{func.__name__}' must accept **kwargs parameter. "
f"Please add '**kwargs' to the function signature."
f"Function '{func.__name__}' must accept **kwargs parameter. Please add '**kwargs' to the function signature."
)

# Setup resources once when the decorator is applied
Expand Down Expand Up @@ -113,7 +112,7 @@ def _is_list_of_message_annotation(annotation: Any) -> bool:
inner = non_none[0]
inner_origin = get_origin(inner)
inner_args = get_args(inner)
return inner_origin in (list, List) and inner_args and inner_args[0] == Message
return (inner_origin in (list, List)) and bool(inner_args) and (inner_args[0] == Message)
return False

def _prepare_final_args(*args: Any, **kwargs: Any):
Expand Down
2 changes: 1 addition & 1 deletion eval_protocol/utils/batch_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def transform_n_variant_jsonl_to_batch_format(
request_id_field: str = "request_id",
response_id_field: str = "response_id",
messages_field: str = "full_conversation_history",
fallback_messages_fields: List[str] = None,
fallback_messages_fields: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""
Transform N-variant generation JSONL output into batch evaluation format.
Expand Down
4 changes: 2 additions & 2 deletions eval_protocol/utils/vite_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
from pathlib import Path
from typing import AsyncGenerator, Callable, Optional
from typing import AsyncGenerator, Callable, Optional, Any

import uvicorn
from fastapi import FastAPI, HTTPException
Expand Down Expand Up @@ -32,7 +32,7 @@ def __init__(
host: str = "localhost",
port: int = 8000,
index_file: str = "index.html",
lifespan: Optional[Callable[[FastAPI], AsyncGenerator[None, None]]] = None,
lifespan: Optional[Callable[[FastAPI], Any]] = None,
):
self.build_dir = Path(build_dir)
self.host = host
Expand Down
Loading