From cdf92b5b783c854cb29395947fc332bb3bbc5f1d Mon Sep 17 00:00:00 2001 From: Benny Chen Date: Mon, 1 Sep 2025 22:05:03 +0800 Subject: [PATCH 1/2] type fix round 7 --- .../test_livebench_data_analysis.py | 11 +-- eval_protocol/datasets/loader.py | 8 +- eval_protocol/integrations/braintrust.py | 13 +-- eval_protocol/mcp/client/connection.py | 11 +-- eval_protocol/mcp/execution/manager.py | 90 ++++++++++++++++--- .../default_mcp_gym_rollout_processor.py | 6 +- eval_protocol/rewards/function_calling.py | 4 +- 7 files changed, 105 insertions(+), 38 deletions(-) diff --git a/eval_protocol/benchmarks/test_livebench_data_analysis.py b/eval_protocol/benchmarks/test_livebench_data_analysis.py index 70e852fd..c6a8efd9 100644 --- a/eval_protocol/benchmarks/test_livebench_data_analysis.py +++ b/eval_protocol/benchmarks/test_livebench_data_analysis.py @@ -424,8 +424,9 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]: @evaluation_test( completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], - input_messages=[[[m for m in r.messages] for r in _CTA_ROWS]], - rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}], + # Provide a flat list per run (Sequence[InputMessagesParam]) to match signature + input_messages=[[m for m in r.messages] for r in _CTA_ROWS], + rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}}, rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=None, @@ -467,8 +468,8 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow: @evaluation_test( completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], - input_messages=[[[m for m in r.messages] for r in _TABLEJOIN_ROWS]], - rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}], + input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS], + rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}}, rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEJOIN_ROWS), aggregation_method="mean", passed_threshold=None, @@ -511,7 +512,7 @@ def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow: @evaluation_test( completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS], - rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}], + rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}}, rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEREFORMAT_ROWS), aggregation_method="mean", passed_threshold=None, diff --git a/eval_protocol/datasets/loader.py b/eval_protocol/datasets/loader.py index 487f0f26..ce4c8a25 100644 --- a/eval_protocol/datasets/loader.py +++ b/eval_protocol/datasets/loader.py @@ -97,7 +97,7 @@ def load_and_process_dataset( # preprocessing_steps: Optional[List[str]] = None, # To be implemented hf_extra_load_params: Optional[Dict[str, Any]] = None, **kwargs: Any, # Catch-all for other params -) -> Union[Dataset, DatasetDict, List[Dict[str, Any]]]: +) -> Union[Dataset, DatasetDict]: """ Loads a dataset from the specified source. @@ -116,7 +116,8 @@ def load_and_process_dataset( Returns: Loaded dataset, typically as Hugging Face Dataset or DatasetDict. """ - loaded_dataset: Union[Dataset, DatasetDict, List[Dict[str, Any]]] + # Hugging Face load_dataset always returns Dataset or DatasetDict in our supported modes + loaded_dataset: Union[Dataset, DatasetDict] # Prepare kwargs for datasets.load_dataset, separating out custom ones load_kwargs_for_hf = hf_extra_load_params.copy() if hf_extra_load_params else {} @@ -238,9 +239,6 @@ def load_and_process_dataset( for s_name in loaded_dataset.keys(): if len(loaded_dataset[s_name]) > max_samples: loaded_dataset[s_name] = loaded_dataset[s_name].select(range(max_samples)) - elif isinstance(loaded_dataset, list): # Should not happen if always converting to HF Dataset - if len(loaded_dataset) > max_samples: - loaded_dataset = loaded_dataset[:max_samples] # Apply column mapping if provided if column_mapping_from_kwargs and isinstance(loaded_dataset, (Dataset, DatasetDict)): diff --git a/eval_protocol/integrations/braintrust.py b/eval_protocol/integrations/braintrust.py index 757a2c8a..78bef72d 100644 --- a/eval_protocol/integrations/braintrust.py +++ b/eval_protocol/integrations/braintrust.py @@ -1,6 +1,6 @@ """Adapters for integrating Eval Protocol with Braintrust scoring functions.""" -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, cast from eval_protocol.models import EvaluateResult, Message from eval_protocol.typed_interface import reward_function @@ -17,8 +17,7 @@ def scorer_to_reward_fn( ) -> Callable[[List[Message], Optional[List[Message]]], EvaluateResult]: """Wrap a Braintrust scorer as an Eval Protocol reward function.""" - @reward_function - def reward_fn( + def reward_fn_core( messages: List[Message], ground_truth: Optional[List[Message]] = None, **kwargs: Any ) -> EvaluateResult: input_val = messages_to_input(messages) if messages_to_input else messages[0].content @@ -29,9 +28,11 @@ def reward_fn( ground_truth_to_expected(ground_truth) if ground_truth_to_expected else ground_truth[-1].content ) score = scorer(input_val, output_val, expected_val) - return EvaluateResult(score=score) + return EvaluateResult(score=float(score)) - return reward_fn + # Wrap with reward_function decorator while preserving precise callable type for type checker + wrapped = reward_function(reward_fn_core) + return cast(Callable[[List[Message], Optional[List[Message]]], EvaluateResult], wrapped) def reward_fn_to_scorer( @@ -48,6 +49,6 @@ def scorer(input_val: Any, output: Any, expected: Any) -> float: if expected is not None: ground_truth = [Message(role="assistant", content=str(expected))] result = reward_fn(messages=messages, ground_truth=ground_truth) - return result.score + return float(result.score) return scorer diff --git a/eval_protocol/mcp/client/connection.py b/eval_protocol/mcp/client/connection.py index a6fcd53d..c239ae97 100644 --- a/eval_protocol/mcp/client/connection.py +++ b/eval_protocol/mcp/client/connection.py @@ -441,9 +441,10 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict) # Extract data plane results (observation only) if tool_result.content and len(tool_result.content) > 0: content = tool_result.content[0] - if hasattr(content, "text"): + text_value = getattr(content, "text", None) + if isinstance(text_value, str): # Fix: Handle empty or invalid JSON responses gracefully - if not content.text or content.text.strip() == "": + if text_value.strip() == "": logger.warning(f"Session {session.session_id}: Empty tool response from {tool_name}") observation = { "observation": "empty_response", @@ -451,14 +452,14 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict) } else: try: - observation = json.loads(content.text) + observation = json.loads(text_value) except json.JSONDecodeError as e: logger.warning( - f"Session {session.session_id}: Invalid JSON from {tool_name}: {content.text}. Error: {e}" + f"Session {session.session_id}: Invalid JSON from {tool_name}: {text_value}. Error: {e}" ) # Create a structured response from the raw text observation = { - "observation": content.text, + "observation": text_value, "session_id": session.session_id, "error": "invalid_json_response", } diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index d6cb2b83..df1ac0a6 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -12,7 +12,7 @@ import threading import time from dataclasses import asdict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast import anyio from openai.types import CompletionUsage @@ -126,7 +126,15 @@ async def _execute_with_semaphore(idx): evaluation_row.messages = messages evaluation_row.tools = shared_tool_schema - evaluation_row.usage = CompletionUsage(**trajectory.usage) + # Some OpenAI SDK versions type CompletionUsage as a TypedDict; construct via cast to avoid ctor mismatches + evaluation_row.usage = cast( + CompletionUsage, + { + "prompt_tokens": trajectory.usage.get("prompt_tokens", 0), + "completion_tokens": trajectory.usage.get("completion_tokens", 0), + "total_tokens": trajectory.usage.get("total_tokens", 0), + }, + ) evaluation_row.input_metadata.completion_params = { "model": policy.model_id, "temperature": getattr(policy, "temperature", None), @@ -138,8 +146,14 @@ async def _execute_with_semaphore(idx): extra_info = None if trajectory.control_plane_summary.get("error_message"): extra_info = {"error_message": trajectory.control_plane_summary.get("error_message")} + # Convert string termination reason to TerminationReason enum if needed + term_reason = ( + trajectory.termination_reason + if isinstance(trajectory.termination_reason, TerminationReason) + else TerminationReason.from_str(str(trajectory.termination_reason)) + ) evaluation_row.rollout_status = Status.rollout_finished( - termination_reason=trajectory.termination_reason, extra_info=extra_info + termination_reason=term_reason, extra_info=extra_info ) else: evaluation_row.rollout_status = Status.rollout_running() @@ -231,8 +245,9 @@ def extract_text_content(msg_dict): # Get initial messages in tau2-bench format for user simulator user_simulator_state = user_simulator.get_init_state() + # Generate initial user response by prompting the simulator with a user role message user_message, user_simulator_state = await user_simulator.generate_next_message( - AssistantMessage(role="assistant", content="Hi! How can I help you today?"), + UserMessage(role="user", content=""), user_simulator_state, ) current_observation = user_message.content if user_message.content else "" @@ -264,8 +279,11 @@ def extract_text_content(msg_dict): # Last message was agent, simulated user response if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage): # Generate user response using the simulator + # Pass the assistant message content to drive the simulated user's next response + last_assistant = user_simulator_messages[-1] user_message, user_simulator_state = await user_simulator.generate_next_message( - user_simulator_messages[-1], user_simulator_state + last_assistant, + user_simulator_state, ) user_content = user_message.content if user_message.content else "" @@ -285,11 +303,33 @@ def extract_text_content(msg_dict): ) update_evaluation_row_messages() - # calc llm usage stats happened in this turn if there is aany + # Update LLM usage stats if available; support both dict-like and attribute access if usage_stats: - trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens - trajectory.usage["completion_tokens"] += usage_stats.completion_tokens - trajectory.usage["total_tokens"] += usage_stats.total_tokens + try: + prompt_tokens = ( + usage_stats.get("prompt_tokens") + if isinstance(usage_stats, dict) + else usage_stats.prompt_tokens + ) + completion_tokens = ( + usage_stats.get("completion_tokens") + if isinstance(usage_stats, dict) + else usage_stats.completion_tokens + ) + total_tokens = ( + usage_stats.get("total_tokens") + if isinstance(usage_stats, dict) + else usage_stats.total_tokens + ) + if isinstance(prompt_tokens, int): + trajectory.usage["prompt_tokens"] += prompt_tokens + if isinstance(completion_tokens, int): + trajectory.usage["completion_tokens"] += completion_tokens + if isinstance(total_tokens, int): + trajectory.usage["total_tokens"] += total_tokens + except Exception: + # Best-effort; ignore malformed usage stats + pass # If no tool call is generated, turn is finished if len(tool_calls) == 1: @@ -300,7 +340,8 @@ def extract_text_content(msg_dict): # If there's no user simulator, then it marks the end of the episode as LLM think there is no tool call needed. elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]: trajectory.terminated = True - trajectory.termination_reason = TerminationReason.from_str(finish_reason) + # Ensure finish_reason is a string before converting + trajectory.termination_reason = TerminationReason.from_str(str(finish_reason)) break # Execute each tool call sequentially @@ -404,11 +445,32 @@ def extract_text_content(msg_dict): ) update_evaluation_row_messages() if usage_stats: - trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens - trajectory.usage["completion_tokens"] += usage_stats.completion_tokens - trajectory.usage["total_tokens"] += usage_stats.total_tokens + try: + prompt_tokens = ( + usage_stats.get("prompt_tokens") + if isinstance(usage_stats, dict) + else usage_stats.prompt_tokens + ) + completion_tokens = ( + usage_stats.get("completion_tokens") + if isinstance(usage_stats, dict) + else usage_stats.completion_tokens + ) + total_tokens = ( + usage_stats.get("total_tokens") + if isinstance(usage_stats, dict) + else usage_stats.total_tokens + ) + if isinstance(prompt_tokens, int): + trajectory.usage["prompt_tokens"] += prompt_tokens + if isinstance(completion_tokens, int): + trajectory.usage["completion_tokens"] += completion_tokens + if isinstance(total_tokens, int): + trajectory.usage["total_tokens"] += total_tokens + except Exception: + pass trajectory.terminated = True - trajectory.termination_reason = TerminationReason.from_str(finish_reason) + trajectory.termination_reason = TerminationReason.from_str(str(finish_reason)) trajectory.control_plane_summary.update( { "total_reward": trajectory.total_reward, diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index 52545cf9..f3cde601 100644 --- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -107,7 +107,7 @@ def _wait_for_server_ready(self, timeout: int = 15) -> bool: while time.time() - start_time < timeout: # Check if process is still running - if self.process.poll() is not None: + if self.process and self.process.poll() is not None: print("Server process exited early") return False @@ -220,7 +220,9 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> self.server.start() self.policy = ep.LiteLLMPolicy( - model_id=config.completion_params.get("model", None), + model_id=str( + (config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini" + ), temperature=config.completion_params.get("temperature", 0.0), max_tokens=config.completion_params.get("max_tokens", 4096), **(config.completion_params.get("extra_body", {}) or {}), diff --git a/eval_protocol/rewards/function_calling.py b/eval_protocol/rewards/function_calling.py index 7b1d60e4..14e8bfe8 100644 --- a/eval_protocol/rewards/function_calling.py +++ b/eval_protocol/rewards/function_calling.py @@ -394,9 +394,11 @@ def exact_tool_match_reward( try: ground_truth = json.loads(ground_truth) except json.JSONDecodeError: + # Cast to string before slicing to satisfy type checker if ground_truth is of unknown type + gt_preview = str(ground_truth) return EvaluateResult( score=0.0, - reason=f"Ground truth was a string but failed to parse as JSON: {ground_truth[:100]}...", + reason=f"Ground truth was a string but failed to parse as JSON: {gt_preview[:100]}...", metrics={}, ) From c0e3ed38c2a7b3ee462619c91efc58f56674bba6 Mon Sep 17 00:00:00 2001 From: Benny Chen Date: Mon, 1 Sep 2025 22:20:30 +0800 Subject: [PATCH 2/2] fix a few more --- .../benchmarks/test_livebench_data_analysis.py | 8 ++++---- eval_protocol/benchmarks/test_tau_bench_airline.py | 2 +- eval_protocol/benchmarks/test_tau_bench_retail.py | 2 +- eval_protocol/execution/pipeline.py | 7 +++++++ eval_protocol/integrations/braintrust.py | 2 +- eval_protocol/mcp/execution/manager.py | 9 ++++++++- eval_protocol/mcp/simulation_server.py | 2 +- .../pytest/default_langchain_rollout_processor.py | 12 ++++++++++-- eval_protocol/rewards/function_calling.py | 6 +++--- eval_protocol/rewards/json_schema.py | 4 ++-- eval_protocol/rewards/lean_prover.py | 2 +- 11 files changed, 39 insertions(+), 17 deletions(-) diff --git a/eval_protocol/benchmarks/test_livebench_data_analysis.py b/eval_protocol/benchmarks/test_livebench_data_analysis.py index c6a8efd9..8c8c5e3c 100644 --- a/eval_protocol/benchmarks/test_livebench_data_analysis.py +++ b/eval_protocol/benchmarks/test_livebench_data_analysis.py @@ -424,8 +424,8 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]: @evaluation_test( completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], - # Provide a flat list per run (Sequence[InputMessagesParam]) to match signature - input_messages=[[m for m in r.messages] for r in _CTA_ROWS], + # Wrap dataset messages in an extra list to match Sequence[list[InputMessagesParam]] + input_messages=[[[m for m in r.messages] for r in _CTA_ROWS]], rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}}, rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", @@ -468,7 +468,7 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow: @evaluation_test( completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], - input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS], + input_messages=[[[m for m in r.messages] for r in _TABLEJOIN_ROWS]], rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}}, rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEJOIN_ROWS), aggregation_method="mean", @@ -511,7 +511,7 @@ def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow: @evaluation_test( completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], - input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS], + input_messages=[[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS]], rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}}, rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEREFORMAT_ROWS), aggregation_method="mean", diff --git a/eval_protocol/benchmarks/test_tau_bench_airline.py b/eval_protocol/benchmarks/test_tau_bench_airline.py index 24417ddc..1cd3149e 100644 --- a/eval_protocol/benchmarks/test_tau_bench_airline.py +++ b/eval_protocol/benchmarks/test_tau_bench_airline.py @@ -147,7 +147,7 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow: messages = row.messages # Get evaluation criteria and user_simulation from input_metadata.dataset_info - dataset_info = row.input_metadata.dataset_info if row.input_metadata else {} + dataset_info = (row.input_metadata.dataset_info or {}) if row.input_metadata else {} evaluation_criteria = dataset_info.get("evaluation_criteria", {}) nl_assertions = evaluation_criteria.get("nl_assertions", []) diff --git a/eval_protocol/benchmarks/test_tau_bench_retail.py b/eval_protocol/benchmarks/test_tau_bench_retail.py index d26d2675..a3faa2b9 100644 --- a/eval_protocol/benchmarks/test_tau_bench_retail.py +++ b/eval_protocol/benchmarks/test_tau_bench_retail.py @@ -137,7 +137,7 @@ def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow: messages = row.messages # Get evaluation criteria and user_simulation from input_metadata.dataset_info - dataset_info = row.input_metadata.dataset_info if row.input_metadata else {} + dataset_info = (row.input_metadata.dataset_info or {}) if row.input_metadata else {} evaluation_criteria = dataset_info.get("evaluation_criteria", {}) nl_assertions = evaluation_criteria.get("nl_assertions", []) diff --git a/eval_protocol/execution/pipeline.py b/eval_protocol/execution/pipeline.py index 8830f3ff..80e25463 100644 --- a/eval_protocol/execution/pipeline.py +++ b/eval_protocol/execution/pipeline.py @@ -87,6 +87,7 @@ async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str) try: backend_requests = [{"backend_name_ref": mcp_backend_ref, "num_instances": 1}] + assert self.mcp_intermediary_client is not None init_response = await self.mcp_intermediary_client.initialize_session(backend_requests) if init_response.get("error"): @@ -109,6 +110,7 @@ async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str) current_instance_id = inst_info_dict.get("instance_id") if not current_instance_id: continue + assert self.mcp_intermediary_client is not None list_tools_result = await self.mcp_intermediary_client.list_backend_tools( rk_session_id=rk_session_id, instance_id=current_instance_id, @@ -130,6 +132,7 @@ async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str) if rk_session_id and self.mcp_intermediary_client: logger.info(f"Sample {sample_id}: Cleaning up tool discovery session '{rk_session_id}'.") try: + assert self.mcp_intermediary_client is not None await self.mcp_intermediary_client.cleanup_session(rk_session_id) except Exception as e_cl: logger.error( @@ -276,6 +279,7 @@ async def _execute_mcp_agent_rollout( try: backend_requests = [{"backend_name_ref": mcp_backend_ref, "num_instances": 1}] + assert self.mcp_intermediary_client is not None init_response = await self.mcp_intermediary_client.initialize_session(backend_requests) if init_response.get("error"): raise RuntimeError( @@ -331,6 +335,7 @@ async def _execute_mcp_agent_rollout( if not isinstance(tool_args_dict, dict): raise ValueError("Args not dict") + assert self.mcp_intermediary_client is not None exec_result = await self.mcp_intermediary_client.call_backend_tool( rk_session_id=rk_session_id, instance_id=primary_instance_id_for_agent_actions, @@ -405,6 +410,7 @@ async def _execute_mcp_agent_rollout( state_capture_tool = self.cfg.agent.get("state_capture_tool") if state_capture_tool: state_capture_args = dict(self.cfg.agent.get("state_capture_args", OmegaConf.create({}))) + assert self.mcp_intermediary_client is not None final_filesystem_state_from_mcp = await self.mcp_intermediary_client.call_backend_tool( rk_session_id=rk_session_id, instance_id=primary_instance_id_for_agent_actions, @@ -432,6 +438,7 @@ async def _execute_mcp_agent_rollout( } finally: if rk_session_id and self.mcp_intermediary_client: + assert self.mcp_intermediary_client is not None await self.mcp_intermediary_client.cleanup_session(rk_session_id) async def _process_single_sample( diff --git a/eval_protocol/integrations/braintrust.py b/eval_protocol/integrations/braintrust.py index 78bef72d..14080bcb 100644 --- a/eval_protocol/integrations/braintrust.py +++ b/eval_protocol/integrations/braintrust.py @@ -48,7 +48,7 @@ def scorer(input_val: Any, output: Any, expected: Any) -> float: ground_truth = None if expected is not None: ground_truth = [Message(role="assistant", content=str(expected))] - result = reward_fn(messages=messages, ground_truth=ground_truth) + result = reward_fn(messages, ground_truth) return float(result.score) return scorer diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index df1ac0a6..753d454e 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -281,8 +281,15 @@ def extract_text_content(msg_dict): # Generate user response using the simulator # Pass the assistant message content to drive the simulated user's next response last_assistant = user_simulator_messages[-1] + # Convert last assistant message into a valid user input message for simulator + from vendor.tau2.data_model.message import UserMessage as TauUserMessage + + converted_user_prompt = ( + last_assistant.content if getattr(last_assistant, "content", None) else "" + ) + converted_message = TauUserMessage(role="user", content=converted_user_prompt) user_message, user_simulator_state = await user_simulator.generate_next_message( - last_assistant, + converted_message, user_simulator_state, ) user_content = user_message.content if user_message.content else "" diff --git a/eval_protocol/mcp/simulation_server.py b/eval_protocol/mcp/simulation_server.py index 801ad0d4..1e5f9ff4 100644 --- a/eval_protocol/mcp/simulation_server.py +++ b/eval_protocol/mcp/simulation_server.py @@ -288,7 +288,7 @@ def _discover_and_register_resources(self): if discovered_resources: @self.app.read_resource() - async def read_resource(uri: str): + async def read_resource(uri: AnyUrl): # Get the current request context ctx = self.app.request_context diff --git a/eval_protocol/pytest/default_langchain_rollout_processor.py b/eval_protocol/pytest/default_langchain_rollout_processor.py index 7d0321d3..35924570 100644 --- a/eval_protocol/pytest/default_langchain_rollout_processor.py +++ b/eval_protocol/pytest/default_langchain_rollout_processor.py @@ -56,9 +56,17 @@ def __init__(self, content: str): # Resolve the appropriate async invoke function if hasattr(target, "graph") and hasattr(target.graph, "ainvoke"): - invoke_fn = target.graph.ainvoke + + async def _invoke_graph(payload): + return await target.graph.ainvoke(payload) # type: ignore[attr-defined] + + invoke_fn = _invoke_graph elif hasattr(target, "ainvoke"): - invoke_fn = target.ainvoke + + async def _invoke_direct(payload): + return await target.ainvoke(payload) # type: ignore[attr-defined] + + invoke_fn = _invoke_direct elif callable(target): async def _invoke_wrapper(payload): diff --git a/eval_protocol/rewards/function_calling.py b/eval_protocol/rewards/function_calling.py index 14e8bfe8..aa0517b6 100644 --- a/eval_protocol/rewards/function_calling.py +++ b/eval_protocol/rewards/function_calling.py @@ -451,7 +451,7 @@ def schema_jaccard_reward( DeprecationWarning, stacklevel=2, ) - return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs) + return exact_tool_match_reward(messages, ground_truth, **kwargs) @reward_function @@ -493,7 +493,7 @@ def llm_judge_reward( DeprecationWarning, stacklevel=2, ) - return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs) + return exact_tool_match_reward(messages, ground_truth, **kwargs) @reward_function @@ -537,7 +537,7 @@ def composite_function_call_reward( DeprecationWarning, stacklevel=2, ) - return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs) + return exact_tool_match_reward(messages, ground_truth, **kwargs) # JSON schema reward functions have been moved to json_schema.py module diff --git a/eval_protocol/rewards/json_schema.py b/eval_protocol/rewards/json_schema.py index eccfa7f0..e48c1160 100644 --- a/eval_protocol/rewards/json_schema.py +++ b/eval_protocol/rewards/json_schema.py @@ -290,8 +290,8 @@ def json_schema_reward_with_llm_judge( normalized_weights = {k: v / total_weight for k, v in weights.items()} schema_result = json_schema_reward( - messages=messages, - ground_truth=ground_truth, + messages, + ground_truth, json_content=json_content, expected_schema=expected_schema, **kwargs, diff --git a/eval_protocol/rewards/lean_prover.py b/eval_protocol/rewards/lean_prover.py index f134fcfe..347cf6a4 100644 --- a/eval_protocol/rewards/lean_prover.py +++ b/eval_protocol/rewards/lean_prover.py @@ -417,7 +417,7 @@ def deepseek_huggingface_prover_benchmark( expected_proof = expected_proof_from_gt reference_solution = None if dataset_item: - if not expected_proof: + if not expected_proof and dataset_item is not None: expected_proof = dataset_item.get("expected_proof", None) reference_solution = dataset_item.get("reference_solution", None) proof_reference = expected_proof or reference_solution