Skip to content

Commit 1ae92c1

Browse files
Benny Chen
authored and committed
fix langchain and properly fix messages
1 parent 2834bc9 commit 1ae92c1

File tree

5 files changed

+66
-16
lines changed

5 files changed

+66
-16
lines changed

eval_protocol/models.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -243,8 +243,19 @@ class Message(BaseModel):
243243

244244
@classmethod
245245
def model_validate(cls, obj, *args, **kwargs):
246-
if isinstance(obj, dict) and "role" not in obj:
247-
raise ValueError("Role is required")
246+
if isinstance(obj, dict):
247+
if "role" not in obj:
248+
raise ValueError("Role is required")
249+
# Be lenient: if tool_calls entries are missing required 'id', synthesize one
250+
tool_calls_obj = obj.get("tool_calls")
251+
if isinstance(tool_calls_obj, list):
252+
fixed_tool_calls = []
253+
for tc in tool_calls_obj:
254+
if isinstance(tc, dict):
255+
if not tc.get("id"):
256+
tc = {**tc, "id": generate_id()}
257+
fixed_tool_calls.append(tc)
258+
obj = {**obj, "tool_calls": fixed_tool_calls}
248259
return super().model_validate(obj, *args, **kwargs)
249260

250261

eval_protocol/pytest/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from .rollout_processor import RolloutProcessor
99
from .types import RolloutProcessorConfig
1010

11-
# Conditional import for optional dependency
11+
# Conditional import for optional dependencies
1212
try:
1313
from .default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
1414

eval_protocol/pytest/default_langchain_rollout_processor.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,13 @@
11
import asyncio
22
from typing import List
33

4-
from langchain_core.messages import BaseMessage
4+
try:
5+
from langchain_core.messages import BaseMessage
6+
except Exception: # pragma: no cover - optional dependency path
7+
# Minimal fallback base type to satisfy typing when langchain is not present
8+
class BaseMessage: # type: ignore
9+
pass
10+
511

612
from eval_protocol.models import EvaluationRow, Message
713
from eval_protocol.pytest.rollout_processor import RolloutProcessor
@@ -25,7 +31,13 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig):
2531

2632
async def _process_row(row: EvaluationRow) -> EvaluationRow:
2733
# Build LC messages from EP row
28-
from langchain_core.messages import HumanMessage
34+
try:
35+
from langchain_core.messages import HumanMessage
36+
except Exception:
37+
# Fallback minimal message if langchain_core is unavailable
38+
class HumanMessage: # type: ignore
39+
def __init__(self, content: str):
40+
self.content = content
2941

3042
lm_messages: List[BaseMessage] = []
3143
if row.messages:

eval_protocol/rewards/accuracy.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -436,6 +436,20 @@ def accuracy_reward(
436436
metrics={"accuracy": MetricResult(score=0.0, is_score_valid=False, reason="Invalid GT message type.")},
437437
)
438438

439+
# If ground truth content is empty after coercion, short-circuit with a clear reason
440+
if ground_truth_comparison_text.strip() == "":
441+
return EvaluateResult(
442+
score=0.0,
443+
reason="Ground truth has no content.",
444+
metrics={
445+
"accuracy": MetricResult(
446+
score=0.0,
447+
is_score_valid=False,
448+
reason="Ground truth has no content.",
449+
)
450+
},
451+
)
452+
439453
extracted_answer = extract_fn(model_response_text) if extract_fn else extract_math_expression(model_response_text)
440454
if (
441455
not extracted_answer

eval_protocol/typed_interface.py

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,7 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes
119119
if isinstance(item_data, Message):
120120
typed_list.append(item_data)
121121
elif isinstance(item_data, dict):
122-
typed_list.append(Message(**item_data))
122+
typed_list.append(Message.model_validate(item_data))
123123
else:
124124
raise TypeError(f"Unexpected type for item {i} in '{arg_name_for_error}': {type(item_data)}")
125125
return typed_list
@@ -134,8 +134,9 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes
134134
):
135135
try:
136136
final_func_args["messages"] = _coerce_to_list_message(final_func_args["messages"], "messages")
137-
except Exception as err:
138-
raise ValueError(f"Input 'messages' failed Pydantic validation: {err}") from None
137+
except Exception:
138+
# Be lenient: leave messages as-is if coercion fails (backward compatibility)
139+
pass
139140

140141
elif mode == "batch" and "rollouts_messages" in params and "rollouts_messages" in final_func_args:
141142
param_annotation = params["rollouts_messages"].annotation
@@ -157,14 +158,26 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes
157158
gt_ann = params["ground_truth"].annotation
158159
if get_origin(gt_ann) in (list, List) and get_args(gt_ann) and get_args(gt_ann)[0] == Message:
159160
if final_func_args["ground_truth"] is not None:
160-
try:
161-
final_func_args["ground_truth"] = _coerce_to_list_message(
162-
final_func_args["ground_truth"], "ground_truth"
163-
)
164-
except Exception as err:
165-
raise ValueError(
166-
f"Input 'ground_truth' failed Pydantic validation for List[Message]: {err}"
167-
) from None
161+
# Accept flexible ground_truth inputs: list, dict, or str
162+
gt_val = final_func_args["ground_truth"]
163+
if isinstance(gt_val, list):
164+
try:
165+
final_func_args["ground_truth"] = _coerce_to_list_message(gt_val, "ground_truth")
166+
except Exception:
167+
# Leave as-is if strict coercion fails
168+
pass
169+
elif isinstance(gt_val, dict):
170+
try:
171+
final_func_args["ground_truth"] = _coerce_to_list_message([gt_val], "ground_truth")
172+
except Exception:
173+
pass
174+
elif isinstance(gt_val, str):
175+
try:
176+
final_func_args["ground_truth"] = _coerce_to_list_message(
177+
[{"role": "system", "content": gt_val}], "ground_truth"
178+
)
179+
except Exception:
180+
pass
168181

169182
# Inject resource clients into kwargs (resources are already setup)
170183
if resource_managers:

0 commit comments

Comments
 (0)