Skip to content

Commit 24cca19

Browse files
benjibc and Benny Chen
co-authored
relax openai dependency and LangGraph support, also reasoning_content support (#118)
* relax openai dependency * langgraph support * rebuild vite app * fix ruff --------- Co-authored-by: Benny Chen <bchen@Bennys-MacBook-Air.local>
1 parent cb42aed commit 24cca19

File tree

15 files changed

+313
-44
lines changed

15 files changed

+313
-44
lines changed
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from typing import List
5+
6+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
7+
8+
from eval_protocol.models import Message
9+
10+
11+
def _dbg_enabled() -> bool:
12+
return os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1"
13+
14+
15+
def _dbg_print(*args):
16+
if _dbg_enabled():
17+
try:
18+
print(*args)
19+
except Exception:
20+
pass
21+
22+
23+
def serialize_lc_message_to_ep(msg: BaseMessage) -> Message:
    """Convert a LangChain message into an eval-protocol ``Message``.

    Supported mappings:
      - ``HumanMessage`` -> role "user" (content stringified)
      - ``AIMessage``    -> role "assistant", with OpenAI-style ``tool_calls``
        (taken from ``additional_kwargs`` first, then the parsed
        ``tool_calls`` attribute) and ``reasoning_content`` gathered from
        "thinking" content parts
      - ``ToolMessage``  -> role "tool", content wrapped in an XML-ish tag
        carrying the tool name and status
      - anything else    -> fallback message whose role is the LC ``type``
        attribute (defaulting to "assistant")
    """
    _dbg_print(
        "[EP-Ser] Input LC msg:",
        type(msg).__name__,
        {
            "has_additional_kwargs": isinstance(getattr(msg, "additional_kwargs", None), dict),
            "content_type": type(getattr(msg, "content", None)).__name__,
        },
    )

    if isinstance(msg, HumanMessage):
        ep_msg = Message(role="user", content=str(msg.content))
        _dbg_print("[EP-Ser] -> EP Message:", {"role": ep_msg.role, "len": len(ep_msg.content or "")})
        return ep_msg

    if isinstance(msg, AIMessage):
        # Flatten text content; list-form content keeps only "text" parts
        # (string items pass through, "thinking" parts are handled below).
        content = ""
        if isinstance(msg.content, str):
            content = msg.content
        elif isinstance(msg.content, list):
            parts: List[str] = []
            for item in msg.content:
                if isinstance(item, dict):
                    if item.get("type") == "text":
                        parts.append(str(item.get("text", "")))
                elif isinstance(item, str):
                    parts.append(item)
            content = "\n".join(parts)

        tool_calls_payload = None

        def _normalize_tool_calls(tc_list: list) -> list[dict]:
            """Map tool calls (OpenAI dict form or LC native form) to OpenAI payload dicts."""
            import json as _json  # hoisted: one import, not one per tool call

            mapped: List[dict] = []
            for idx, call in enumerate(tc_list):
                if not isinstance(call, dict):
                    continue
                try:
                    # Include the call index in the fallback id so several
                    # id-less calls stay distinct (previously every id-less
                    # call collapsed to the same "toolcall_0").
                    call_id = call.get("id") or f"toolcall_{idx}"
                    if isinstance(call.get("function"), dict):
                        # OpenAI wire format: {"function": {"name", "arguments"}}
                        fn = call["function"]
                        fn_name = fn.get("name") or call.get("name") or "tool"
                        fn_args = fn.get("arguments")
                    else:
                        # LC native format: {"name", "args"} (or flat "arguments")
                        fn_name = call.get("name") or "tool"
                        fn_args = call.get("arguments") if call.get("arguments") is not None else call.get("args")
                    if not isinstance(fn_args, str):
                        # OpenAI expects arguments as a JSON-encoded string.
                        fn_args = _json.dumps(fn_args or {}, ensure_ascii=False)
                    mapped.append(
                        {
                            "id": call_id,
                            "type": "function",
                            "function": {"name": fn_name, "arguments": fn_args},
                        }
                    )
                except Exception:
                    # Best effort: skip malformed entries rather than fail the message.
                    continue
            return mapped

        # Prefer the raw provider payload in additional_kwargs; fall back to
        # LC's parsed tool_calls attribute only if that yielded nothing.
        ak = getattr(msg, "additional_kwargs", None)
        if isinstance(ak, dict):
            tc = ak.get("tool_calls")
            if isinstance(tc, list) and tc:
                mapped = _normalize_tool_calls(tc)
                if mapped:
                    tool_calls_payload = mapped

        if tool_calls_payload is None:
            raw_attr_tc = getattr(msg, "tool_calls", None)
            if isinstance(raw_attr_tc, list) and raw_attr_tc:
                mapped = _normalize_tool_calls(raw_attr_tc)
                if mapped:
                    tool_calls_payload = mapped

        # Extract reasoning/thinking parts into reasoning_content.
        reasoning_content = None
        if isinstance(msg.content, list):
            collected = [
                it.get("thinking", "") for it in msg.content if isinstance(it, dict) and it.get("type") == "thinking"
            ]
            if collected:
                # Drop empty fragments; None when nothing substantive remains.
                reasoning_content = "\n\n".join([s for s in collected if s]) or None

        ep_msg = Message(
            role="assistant", content=content, tool_calls=tool_calls_payload, reasoning_content=reasoning_content
        )
        _dbg_print(
            "[EP-Ser] -> EP Message:",
            {
                "role": ep_msg.role,
                "content_len": len(ep_msg.content or ""),
                "tool_calls": len(ep_msg.tool_calls or []) if isinstance(ep_msg.tool_calls, list) else 0,
            },
        )
        return ep_msg

    if isinstance(msg, ToolMessage):
        tool_name = msg.name or "tool"
        status = msg.status or "success"
        content = str(msg.content)
        tool_call_id = getattr(msg, "tool_call_id", None)
        # Wrap the tool output in a tag carrying name/status so downstream
        # evaluators can see which tool produced it and whether it succeeded.
        ep_msg = Message(
            role="tool",
            name=tool_name,
            tool_call_id=tool_call_id,
            content=f'<{tool_name} status="{status}">\n{content}\n</{tool_name}>',
        )
        _dbg_print(
            "[EP-Ser] -> EP Message:", {"role": ep_msg.role, "name": ep_msg.name, "has_id": bool(ep_msg.tool_call_id)}
        )
        return ep_msg

    # Unknown message type: best-effort fallback using the LC "type" as role.
    ep_msg = Message(role=getattr(msg, "type", "assistant"), content=str(getattr(msg, "content", "")))
    _dbg_print("[EP-Ser] -> EP Message (fallback):", {"role": ep_msg.role, "len": len(ep_msg.content or "")})
    return ep_msg

eval_protocol/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,9 @@ class Message(BaseModel):
230230
content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]] = Field(
231231
default="", description="The content of the message."
232232
)
233+
reasoning_content: Optional[str] = Field(
234+
default=None, description="Optional hidden chain-of-thought or reasoning content."
235+
)
233236
name: Optional[str] = None
234237
tool_call_id: Optional[str] = None
235238
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None

eval_protocol/pytest/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
44
from .default_no_op_rollout_processor import NoOpRolloutProcessor
55
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
6+
from .default_langchain_rollout_processor import LangGraphRolloutProcessor
67
from .evaluation_test import evaluation_test
78
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
89
from .rollout_processor import RolloutProcessor
@@ -22,6 +23,7 @@
2223
"MCPGymRolloutProcessor",
2324
"RolloutProcessor",
2425
"SingleTurnRolloutProcessor",
26+
"LangGraphRolloutProcessor",
2527
"NoOpRolloutProcessor",
2628
"default_dataset_adapter",
2729
"RolloutProcessorConfig",
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import asyncio
2+
from typing import List
3+
4+
from langchain_core.messages import BaseMessage
5+
6+
from eval_protocol.models import EvaluationRow, Message
7+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
8+
from eval_protocol.pytest.types import RolloutProcessorConfig
9+
10+
11+
class LangGraphRolloutProcessor(RolloutProcessor):
    """Generic rollout processor for LangChain agents.

    Accepts an async factory that returns a target to invoke. The target can be:
    - An object with `.graph.ainvoke(payload)` (e.g., LangGraph compiled graph)
    - An object with `.ainvoke(payload)`
    - A callable that accepts `payload` and returns the result dict
    """

    def __init__(self, get_invoke_target):
        # Async factory: (config) -> invoke target (see class docstring).
        self.get_invoke_target = get_invoke_target

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig):
        async def _run_one(row: EvaluationRow) -> EvaluationRow:
            from langchain_core.messages import HumanMessage

            # Seed the agent with the most recent user turn from the EP row,
            # or a minimal empty human message when there is none.
            user_turns = [m for m in (row.messages or []) if m.role == "user"]
            if user_turns:
                seed: List[BaseMessage] = [HumanMessage(content=user_turns[-1].content or "")]
            else:
                seed = [HumanMessage(content="")]  # minimal

            target = await self.get_invoke_target(config)

            # Resolve the async invoke entry point, most specific first.
            graph = getattr(target, "graph", None)
            if graph is not None and hasattr(graph, "ainvoke"):
                invoke = graph.ainvoke
            elif hasattr(target, "ainvoke"):
                invoke = target.ainvoke
            elif callable(target):

                async def invoke(payload):
                    return await target(payload)

            else:
                raise TypeError("Unsupported invoke target for LangGraphRolloutProcessor")

            result = await invoke({"messages": seed})
            produced: List[BaseMessage] = result.get("messages", [])

            def _to_ep(m: BaseMessage) -> Message:
                # Prefer the SDK-level serializer; degrade to plain string content.
                try:
                    from eval_protocol.adapters.langchain import serialize_lc_message_to_ep as _ser

                    return _ser(m)
                except Exception:
                    # Minimal fallback: best-effort string content only
                    return Message(role=getattr(m, "type", "assistant"), content=str(getattr(m, "content", "")))

            row.messages = [_to_ep(m) for m in produced]
            return row

        # One task per row; caller awaits/gathers them.
        return [asyncio.create_task(_run_one(r)) for r in rows]

    def cleanup(self) -> None:
        # Nothing to release; the factory owns any long-lived resources.
        return None

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import List
66

77
from litellm import acompletion
8-
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
8+
from typing import Dict
99

1010
from eval_protocol.dataset_logger import default_logger
1111
from eval_protocol.models import EvaluationRow, Message
@@ -71,17 +71,34 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
7171

7272
converted_tool_calls = None
7373
if tool_calls:
74-
converted_tool_calls = [
75-
ChatCompletionMessageToolCall(
76-
id=tool_call.id,
77-
type=tool_call.type,
78-
function={
79-
"name": tool_call.function.name,
80-
"arguments": tool_call.function.arguments,
81-
},
82-
)
83-
for tool_call in tool_calls
84-
]
74+
converted_tool_calls = []
75+
for tool_call in tool_calls:
76+
try:
77+
converted_tool_calls.append(
78+
{
79+
"id": tool_call.id,
80+
"type": tool_call.type,
81+
"function": {
82+
"name": tool_call.function.name,
83+
"arguments": tool_call.function.arguments,
84+
},
85+
}
86+
)
87+
except Exception:
88+
# best-effort: fallback to dict form
89+
try:
90+
converted_tool_calls.append(
91+
{
92+
"id": getattr(tool_call, "id", "toolcall_0"),
93+
"type": getattr(tool_call, "type", "function"),
94+
"function": {
95+
"name": getattr(getattr(tool_call, "function", None), "name", "tool"),
96+
"arguments": getattr(getattr(tool_call, "function", None), "arguments", "{}"),
97+
},
98+
}
99+
)
100+
except Exception:
101+
pass
85102

86103
messages = list(row.messages) + [
87104
Message(

eval_protocol/pytest/evaluation_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,24 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
445445
)
446446
else:
447447
r.eval_metadata.status = Status.eval_finished()
448+
# Optional debug print for assistant/tool sequence
449+
if os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1":
450+
try:
451+
preview = [
452+
{
453+
"role": m.role,
454+
"len": len(m.content or "") if isinstance(m.content, str) else None,
455+
"tool_calls": len(m.tool_calls or [])
456+
if hasattr(m, "tool_calls") and isinstance(m.tool_calls, list)
457+
else 0,
458+
"tool_call_id": getattr(m, "tool_call_id", None),
459+
"name": getattr(m, "name", None),
460+
}
461+
for m in r.messages
462+
]
463+
print("[EP-Log] Row messages:", preview)
464+
except Exception:
465+
pass
448466
active_logger.log(r)
449467

450468
# if rollout_processor is McpGymRolloutProcessor, we execute runs sequentially since McpGym does not support concurrent runs

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies = [
2222
"dataclasses-json>=0.5.7",
2323
"uvicorn>=0.15.0",
2424
"python-dotenv>=0.19.0",
25-
"openai==1.78.1",
25+
"openai>=1.78.1",
2626
"aiosqlite",
2727
"aiohttp",
2828
"mcp>=1.9.2",
@@ -71,7 +71,7 @@ dev = [
7171
"types-PyYAML",
7272
"types-docker",
7373
"versioneer>=0.20",
74-
"openai==1.78.1",
74+
"openai>=1.78.1",
7575
"pre-commit",
7676
"e2b",
7777
"pytest-cov",

uv.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)