Skip to content

Commit 630804b

Browse files
benjibc (Benny Chen)
authored and committed
langgraph support
1 parent 8952208 commit 630804b

File tree

9 files changed

+264
-12
lines changed

9 files changed

+264
-12
lines changed
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from typing import List
5+
6+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
7+
8+
from eval_protocol.models import Message
9+
10+
11+
def _dbg_enabled() -> bool:
12+
return os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1"
13+
14+
15+
def _dbg_print(*args):
    """Print *args* only when serialization debugging is enabled.

    Print failures (e.g. a closed or broken stdout stream) are deliberately
    swallowed so debug output can never break serialization itself.
    """
    if not _dbg_enabled():
        return
    try:
        print(*args)
    except Exception:
        pass
21+
22+
23+
def serialize_lc_message_to_ep(msg: BaseMessage) -> Message:
    """Convert a LangChain message into an eval-protocol ``Message``.

    Supported inputs:
    - ``HumanMessage`` -> role "user"
    - ``AIMessage``    -> role "assistant"; text content is flattened,
      tool calls are normalized to the OpenAI-compatible shape, and
      "thinking" content parts are collected into ``reasoning_content``.
    - ``ToolMessage``  -> role "tool"; content is wrapped in a
      ``<name status="...">`` tag.
    Any other message type gets a best-effort string serialization, with
    LangChain's "ai"/"human" type names mapped to standard EP roles.
    """
    _dbg_print("[EP-Ser] Input LC msg:", type(msg).__name__, {
        "has_additional_kwargs": isinstance(getattr(msg, "additional_kwargs", None), dict),
        "content_type": type(getattr(msg, "content", None)).__name__,
    })

    if isinstance(msg, HumanMessage):
        ep_msg = Message(role="user", content=str(msg.content))
        _dbg_print("[EP-Ser] -> EP Message:", {"role": ep_msg.role, "len": len(ep_msg.content or "")})
        return ep_msg

    if isinstance(msg, AIMessage):
        content = ""
        if isinstance(msg.content, str):
            content = msg.content
        elif isinstance(msg.content, list):
            # Flatten structured content: keep "text" parts and raw strings,
            # ignore other part types (images, thinking, etc.).
            parts: List[str] = []
            for item in msg.content:
                if isinstance(item, dict):
                    if item.get("type") == "text":
                        parts.append(str(item.get("text", "")))
                elif isinstance(item, str):
                    parts.append(item)
            content = "\n".join(parts)

        tool_calls_payload = None

        def _normalize_tool_calls(tc_list: list) -> list[dict]:
            """Normalize OpenAI-style ({"function": {...}}) or LangChain-style
            ({"name": ..., "args": ...}) tool-call dicts into EP's
            OpenAI-compatible payload, skipping malformed entries."""
            # Hoisted out of the loop (was imported once per tool call).
            import json as _json

            mapped: List[dict] = []
            for idx, call in enumerate(tc_list):
                if not isinstance(call, dict):
                    continue
                try:
                    # Index-based fallback id so multiple id-less calls do not
                    # collide (previously every fallback was "toolcall_0").
                    call_id = call.get("id") or f"toolcall_{idx}"
                    if isinstance(call.get("function"), dict):
                        fn = call["function"]
                        fn_name = fn.get("name") or call.get("name") or "tool"
                        fn_args = fn.get("arguments")
                    else:
                        fn_name = call.get("name") or "tool"
                        fn_args = call.get("arguments") if call.get("arguments") is not None else call.get("args")
                    if not isinstance(fn_args, str):
                        fn_args = _json.dumps(fn_args or {}, ensure_ascii=False)
                    mapped.append({
                        "id": call_id,
                        "type": "function",
                        "function": {"name": fn_name, "arguments": fn_args},
                    })
                except Exception:
                    continue
            return mapped

        # Prefer OpenAI-style tool calls from additional_kwargs, then fall
        # back to the LangChain `tool_calls` attribute.
        ak = getattr(msg, "additional_kwargs", None)
        if isinstance(ak, dict):
            tc = ak.get("tool_calls")
            if isinstance(tc, list) and tc:
                mapped = _normalize_tool_calls(tc)
                if mapped:
                    tool_calls_payload = mapped

        if tool_calls_payload is None:
            raw_attr_tc = getattr(msg, "tool_calls", None)
            if isinstance(raw_attr_tc, list) and raw_attr_tc:
                mapped = _normalize_tool_calls(raw_attr_tc)
                if mapped:
                    tool_calls_payload = mapped

        # Extract reasoning/thinking parts into reasoning_content
        reasoning_content = None
        if isinstance(msg.content, list):
            collected = [it.get("thinking", "") for it in msg.content if isinstance(it, dict) and it.get("type") == "thinking"]
            if collected:
                reasoning_content = "\n\n".join([s for s in collected if s]) or None

        ep_msg = Message(role="assistant", content=content, tool_calls=tool_calls_payload, reasoning_content=reasoning_content)
        _dbg_print("[EP-Ser] -> EP Message:", {
            "role": ep_msg.role,
            "content_len": len(ep_msg.content or ""),
            "tool_calls": len(ep_msg.tool_calls or []) if isinstance(ep_msg.tool_calls, list) else 0,
        })
        return ep_msg

    if isinstance(msg, ToolMessage):
        tool_name = msg.name or "tool"
        status = msg.status or "success"
        content = str(msg.content)
        tool_call_id = getattr(msg, "tool_call_id", None)
        ep_msg = Message(
            role="tool",
            name=tool_name,
            tool_call_id=tool_call_id,
            content=f"<{tool_name} status=\"{status}\">\n{content}\n</{tool_name}>",
        )
        _dbg_print("[EP-Ser] -> EP Message:", {"role": ep_msg.role, "name": ep_msg.name, "has_id": bool(ep_msg.tool_call_id)})
        return ep_msg

    # Fallback: map LangChain's message "type" onto EP's standard roles so
    # consumers never see "ai"/"human"; unknown types pass through unchanged.
    raw_role = getattr(msg, "type", "assistant")
    role = {"ai": "assistant", "human": "user"}.get(raw_role, raw_role)
    ep_msg = Message(role=role, content=str(getattr(msg, "content", "")))
    _dbg_print("[EP-Ser] -> EP Message (fallback):", {"role": ep_msg.role, "len": len(ep_msg.content or "")})
    return ep_msg
124+
125+

eval_protocol/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,18 @@ class ChatCompletionContentPartTextParam(BaseModel):
223223
type: Literal["text"] = Field("text", description="The type of the content part.")
224224

225225

226+
227+
226228
class Message(BaseModel):
227229
"""Chat message model with trajectory evaluation support."""
228230

229231
role: str # assistant, user, system, tool
230232
content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]] = Field(
231233
default="", description="The content of the message."
232234
)
235+
reasoning_content: Optional[str] = Field(
236+
default=None, description="Optional hidden chain-of-thought or reasoning content."
237+
)
233238
name: Optional[str] = None
234239
tool_call_id: Optional[str] = None
235240
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None

eval_protocol/pytest/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
44
from .default_no_op_rollout_processor import NoOpRolloutProcessor
55
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
6+
from .default_langchain_rollout_processor import LangGraphRolloutProcessor
67
from .evaluation_test import evaluation_test
78
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
89
from .rollout_processor import RolloutProcessor
@@ -22,6 +23,7 @@
2223
"MCPGymRolloutProcessor",
2324
"RolloutProcessor",
2425
"SingleTurnRolloutProcessor",
26+
"LangGraphRolloutProcessor",
2527
"NoOpRolloutProcessor",
2628
"default_dataset_adapter",
2729
"RolloutProcessorConfig",
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import asyncio
2+
from typing import List
3+
4+
from langchain_core.messages import BaseMessage
5+
6+
from eval_protocol.models import EvaluationRow, Message
7+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
8+
from eval_protocol.pytest.types import RolloutProcessorConfig
9+
10+
11+
class LangGraphRolloutProcessor(RolloutProcessor):
    """Generic rollout processor for LangChain agents.

    Accepts an async factory that returns a target to invoke. The target can be:
    - An object with `.graph.ainvoke(payload)` (e.g., LangGraph compiled graph)
    - An object with `.ainvoke(payload)`
    - A callable (sync or async) that accepts `payload` and returns the result dict
    """

    def __init__(self, get_invoke_target):
        # Async factory: (config) -> invoke target; see class docstring.
        self.get_invoke_target = get_invoke_target

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig):
        """Schedule one asyncio task per row and return the list of tasks."""

        async def _process_row(row: EvaluationRow) -> EvaluationRow:
            # Build LC messages from EP row. NOTE(review): only the most
            # recent user message is forwarded; earlier history and system
            # messages are dropped — confirm this is intended.
            from langchain_core.messages import HumanMessage

            lm_messages: List[BaseMessage] = []
            if row.messages:
                user_messages = [m for m in row.messages if m.role == "user"]
                if user_messages:
                    lm_messages.append(HumanMessage(content=user_messages[-1].content or ""))
            if not lm_messages:
                lm_messages = [HumanMessage(content="")]  # minimal

            target = await self.get_invoke_target(config)

            # Resolve the appropriate async invoke function
            if hasattr(target, "graph") and hasattr(target.graph, "ainvoke"):
                invoke_fn = target.graph.ainvoke
            elif hasattr(target, "ainvoke"):
                invoke_fn = target.ainvoke
            elif callable(target):

                async def _invoke_wrapper(payload):
                    # Support both async and plain-sync callables; the
                    # original unconditionally awaited the return value and
                    # raised TypeError for sync callables that the class
                    # docstring claims to support.
                    import inspect

                    result = target(payload)
                    if inspect.isawaitable(result):
                        result = await result
                    return result

                invoke_fn = _invoke_wrapper
            else:
                raise TypeError("Unsupported invoke target for LangGraphRolloutProcessor")

            result = await invoke_fn({"messages": lm_messages})
            result_messages: List[BaseMessage] = result.get("messages", [])

            def _serialize_message(msg: BaseMessage) -> Message:
                # Prefer SDK-level serializer
                try:
                    from eval_protocol.adapters.langchain import serialize_lc_message_to_ep as _ser

                    return _ser(msg)
                except Exception:
                    # Minimal fallback: best-effort string content only
                    content = getattr(msg, "content", "")
                    return Message(role=getattr(msg, "type", "assistant"), content=str(content))

            # Replace the row's messages with the full serialized transcript
            # returned by the target.
            row.messages = [_serialize_message(m) for m in result_messages]
            return row

        return [asyncio.create_task(_process_row(r)) for r in rows]

    def cleanup(self) -> None:
        """No resources to release."""
        return None
75+
76+

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import List
66

77
from litellm import acompletion
8-
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
8+
from typing import Dict
99

1010
from eval_protocol.dataset_logger import default_logger
1111
from eval_protocol.models import EvaluationRow, Message
@@ -71,17 +71,30 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
7171

7272
converted_tool_calls = None
7373
if tool_calls:
74-
converted_tool_calls = [
75-
ChatCompletionMessageToolCall(
76-
id=tool_call.id,
77-
type=tool_call.type,
78-
function={
79-
"name": tool_call.function.name,
80-
"arguments": tool_call.function.arguments,
81-
},
82-
)
83-
for tool_call in tool_calls
84-
]
74+
converted_tool_calls = []
75+
for tool_call in tool_calls:
76+
try:
77+
converted_tool_calls.append({
78+
"id": tool_call.id,
79+
"type": tool_call.type,
80+
"function": {
81+
"name": tool_call.function.name,
82+
"arguments": tool_call.function.arguments,
83+
},
84+
})
85+
except Exception:
86+
# best-effort: fallback to dict form
87+
try:
88+
converted_tool_calls.append({
89+
"id": getattr(tool_call, "id", "toolcall_0"),
90+
"type": getattr(tool_call, "type", "function"),
91+
"function": {
92+
"name": getattr(getattr(tool_call, "function", None), "name", "tool"),
93+
"arguments": getattr(getattr(tool_call, "function", None), "arguments", "{}"),
94+
},
95+
})
96+
except Exception:
97+
pass
8598

8699
messages = list(row.messages) + [
87100
Message(

eval_protocol/pytest/evaluation_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,22 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
445445
)
446446
else:
447447
r.eval_metadata.status = Status.eval_finished()
448+
# Optional debug print for assistant/tool sequence
449+
if os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1":
450+
try:
451+
preview = [
452+
{
453+
"role": m.role,
454+
"len": len(m.content or "") if isinstance(m.content, str) else None,
455+
"tool_calls": len(m.tool_calls or []) if hasattr(m, "tool_calls") and isinstance(m.tool_calls, list) else 0,
456+
"tool_call_id": getattr(m, "tool_call_id", None),
457+
"name": getattr(m, "name", None),
458+
}
459+
for m in r.messages
460+
]
461+
print("[EP-Log] Row messages:", preview)
462+
except Exception:
463+
pass
448464
active_logger.log(r)
449465

450466
# if rollout_processor is McpGymRolloutProcessor, we execute runs sequentially since McpGym does not support concurrent runs

vite-app/dist/assets/index-CgOSTZTF.js.map

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vite-app/src/components/MessageBubble.tsx

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ export const MessageBubble = ({ message }: { message: Message }) => {
1313
const hasFunctionCall = message.function_call;
1414

1515
// Get the message content as a string
16+
const reasoning = (message as any).reasoning_content as string | undefined;
1617
const getMessageContent = () => {
1718
if (typeof message.content === "string") {
1819
return message.content;
@@ -104,6 +105,17 @@ export const MessageBubble = ({ message }: { message: Message }) => {
104105
{isExpanded ? "Show less" : "Show more"}
105106
</button>
106107
)}
108+
{reasoning && reasoning.trim().length > 0 && (
109+
<div className={`mt-2 pt-1 border-t ${isTool ? "border-green-200" : "border-yellow-200"}`}>
110+
<div className={`font-semibold text-xs mb-0.5 ${isTool ? "text-green-700" : "text-yellow-700"}`}>
111+
Thinking:
112+
</div>
113+
<details className="mb-1">
114+
<summary className={`cursor-pointer text-xs ${isTool ? "text-green-700" : "text-yellow-700"}`}>Show reasoning</summary>
115+
<pre className={`mt-1 p-1 border rounded text-xs whitespace-pre-wrap break-words ${isTool ? "bg-green-100 border-green-200 text-green-800" : "bg-yellow-100 border-yellow-200 text-yellow-800"}`}>{reasoning}</pre>
116+
</details>
117+
</div>
118+
)}
107119
{hasToolCalls && message.tool_calls && (
108120
<div
109121
className={`mt-2 pt-1 border-t ${

vite-app/src/types/eval-protocol.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ export const ChatCompletionContentPartTextParamSchema = z.object({
77
type: z.literal('text').default('text').describe('The type of the content part.')
88
});
99

10+
1011
export const FunctionCallSchema = z.object({
1112
name: z.string(),
1213
arguments: z.string()
@@ -21,6 +22,7 @@ export const ChatCompletionMessageToolCallSchema = z.object({
2122
export const MessageSchema = z.object({
2223
role: z.string().describe('assistant, user, system, tool'),
2324
content: z.union([z.string(), z.array(ChatCompletionContentPartTextParamSchema)]).optional().default('').describe('The content of the message.'),
25+
reasoning_content: z.string().optional().describe('Optional hidden chain-of-thought or reasoning content.'),
2426
name: z.string().optional(),
2527
tool_call_id: z.string().optional(),
2628
tool_calls: z.array(ChatCompletionMessageToolCallSchema).optional(),

0 commit comments

Comments
 (0)