Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions eval_protocol/adapters/langchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from __future__ import annotations

import json
import os
from typing import List

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage

from eval_protocol.models import Message


def _dbg_enabled() -> bool:
return os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1"


def _dbg_print(*args):
    """Print *args* when serialization debugging is enabled.

    Any exception raised by ``print`` (e.g. a closed/broken stdout) is
    swallowed so debug output can never break the caller.
    """
    if not _dbg_enabled():
        return
    try:
        print(*args)
    except Exception:
        # Best-effort logging only; never propagate I/O failures.
        pass


def serialize_lc_message_to_ep(msg: BaseMessage) -> Message:
    """Convert a LangChain message into an eval-protocol ``Message``.

    Handling by concrete type:
      - ``HumanMessage``: role ``user`` with stringified content.
      - ``AIMessage``: role ``assistant``; text parts are joined with newlines,
        tool calls (from ``additional_kwargs["tool_calls"]`` first, then the
        ``tool_calls`` attribute) are normalized to the OpenAI function-call
        dict shape, and ``thinking`` content parts become ``reasoning_content``.
      - ``ToolMessage``: role ``tool`` with the output wrapped in an XML-like
        ``<name status="...">`` envelope.
      - anything else: best-effort fallback using ``msg.type`` / ``msg.content``.

    Args:
        msg: The LangChain message to serialize.

    Returns:
        The equivalent eval-protocol ``Message``.
    """
    _dbg_print(
        "[EP-Ser] Input LC msg:",
        type(msg).__name__,
        {
            "has_additional_kwargs": isinstance(getattr(msg, "additional_kwargs", None), dict),
            "content_type": type(getattr(msg, "content", None)).__name__,
        },
    )

    if isinstance(msg, HumanMessage):
        ep_msg = Message(role="user", content=str(msg.content))
        _dbg_print("[EP-Ser] -> EP Message:", {"role": ep_msg.role, "len": len(ep_msg.content or "")})
        return ep_msg

    if isinstance(msg, AIMessage):
        # Flatten content to a plain string: either the raw string, or the
        # "text" parts of a structured content list joined with newlines.
        content = ""
        if isinstance(msg.content, str):
            content = msg.content
        elif isinstance(msg.content, list):
            parts: List[str] = []
            for item in msg.content:
                if isinstance(item, dict):
                    if item.get("type") == "text":
                        parts.append(str(item.get("text", "")))
                elif isinstance(item, str):
                    parts.append(item)
            content = "\n".join(parts)

        tool_calls_payload = None

        def _normalize_tool_calls(tc_list: list) -> list[dict]:
            """Normalize LC/OpenAI-style tool-call dicts to the OpenAI
            ``{"id", "type", "function"}`` shape, skipping malformed entries."""
            mapped: List[dict] = []
            for idx, call in enumerate(tc_list):
                if not isinstance(call, dict):
                    continue
                try:
                    # Fall back to a positionally unique id: a constant
                    # fallback ("toolcall_0") would collide when a message
                    # carries more than one id-less call.
                    call_id = call.get("id") or f"toolcall_{idx}"
                    if isinstance(call.get("function"), dict):
                        # Already OpenAI-shaped; prefer its fields.
                        fn = call["function"]
                        fn_name = fn.get("name") or call.get("name") or "tool"
                        fn_args = fn.get("arguments")
                    else:
                        # LangChain-native shape: name/arguments (or "args").
                        fn_name = call.get("name") or "tool"
                        fn_args = call.get("arguments") if call.get("arguments") is not None else call.get("args")
                    if not isinstance(fn_args, str):
                        # OpenAI expects arguments as a JSON string.
                        fn_args = json.dumps(fn_args or {}, ensure_ascii=False)
                    mapped.append(
                        {
                            "id": call_id,
                            "type": "function",
                            "function": {"name": fn_name, "arguments": fn_args},
                        }
                    )
                except Exception:
                    # Skip malformed entries rather than failing the message.
                    continue
            return mapped

        ak = getattr(msg, "additional_kwargs", None)
        if isinstance(ak, dict):
            tc = ak.get("tool_calls")
            if isinstance(tc, list) and tc:
                mapped = _normalize_tool_calls(tc)
                if mapped:
                    tool_calls_payload = mapped

        if tool_calls_payload is None:
            raw_attr_tc = getattr(msg, "tool_calls", None)
            if isinstance(raw_attr_tc, list) and raw_attr_tc:
                mapped = _normalize_tool_calls(raw_attr_tc)
                if mapped:
                    tool_calls_payload = mapped

        # Extract reasoning/thinking parts into reasoning_content
        reasoning_content = None
        if isinstance(msg.content, list):
            collected = [
                it.get("thinking", "") for it in msg.content if isinstance(it, dict) and it.get("type") == "thinking"
            ]
            if collected:
                reasoning_content = "\n\n".join([s for s in collected if s]) or None

        ep_msg = Message(
            role="assistant", content=content, tool_calls=tool_calls_payload, reasoning_content=reasoning_content
        )
        _dbg_print(
            "[EP-Ser] -> EP Message:",
            {
                "role": ep_msg.role,
                "content_len": len(ep_msg.content or ""),
                "tool_calls": len(ep_msg.tool_calls or []) if isinstance(ep_msg.tool_calls, list) else 0,
            },
        )
        return ep_msg

    if isinstance(msg, ToolMessage):
        tool_name = msg.name or "tool"
        status = msg.status or "success"
        content = str(msg.content)
        tool_call_id = getattr(msg, "tool_call_id", None)
        ep_msg = Message(
            role="tool",
            name=tool_name,
            tool_call_id=tool_call_id,
            content=f'<{tool_name} status="{status}">\n{content}\n</{tool_name}>',
        )
        _dbg_print(
            "[EP-Ser] -> EP Message:", {"role": ep_msg.role, "name": ep_msg.name, "has_id": bool(ep_msg.tool_call_id)}
        )
        return ep_msg

    # Unknown message type: best-effort fallback using the LC "type" as role.
    ep_msg = Message(role=getattr(msg, "type", "assistant"), content=str(getattr(msg, "content", "")))
    _dbg_print("[EP-Ser] -> EP Message (fallback):", {"role": ep_msg.role, "len": len(ep_msg.content or "")})
    return ep_msg
3 changes: 3 additions & 0 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ class Message(BaseModel):
content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]] = Field(
default="", description="The content of the message."
)
reasoning_content: Optional[str] = Field(
default=None, description="Optional hidden chain-of-thought or reasoning content."
)
name: Optional[str] = None
tool_call_id: Optional[str] = None
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
Expand Down
2 changes: 2 additions & 0 deletions eval_protocol/pytest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
from .default_no_op_rollout_processor import NoOpRolloutProcessor
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
from .default_langchain_rollout_processor import LangGraphRolloutProcessor
from .evaluation_test import evaluation_test
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
from .rollout_processor import RolloutProcessor
Expand All @@ -22,6 +23,7 @@
"MCPGymRolloutProcessor",
"RolloutProcessor",
"SingleTurnRolloutProcessor",
"LangGraphRolloutProcessor",
"NoOpRolloutProcessor",
"default_dataset_adapter",
"RolloutProcessorConfig",
Expand Down
77 changes: 77 additions & 0 deletions eval_protocol/pytest/default_langchain_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import asyncio
from typing import List

from langchain_core.messages import BaseMessage

from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig


class LangGraphRolloutProcessor(RolloutProcessor):
    """Rollout processor that drives a LangChain / LangGraph agent.

    The constructor receives an async factory; awaiting it yields the object
    to invoke, which may be any of:
      - something exposing ``.graph.ainvoke(payload)`` (a compiled LangGraph),
      - something exposing ``.ainvoke(payload)``,
      - a plain async callable taking ``payload`` and returning a result dict.
    """

    def __init__(self, get_invoke_target):
        # Async factory: ``await get_invoke_target(config)`` yields the target.
        self.get_invoke_target = get_invoke_target

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig):
        async def _run_one(row: EvaluationRow) -> EvaluationRow:
            from langchain_core.messages import HumanMessage

            # Seed the agent with the most recent user turn, if any.
            seed: List[BaseMessage] = []
            if row.messages:
                user_turns = [m for m in row.messages if m.role == "user"]
                if user_turns:
                    seed.append(HumanMessage(content=user_turns[-1].content or ""))
            if not seed:
                seed = [HumanMessage(content="")]  # minimal

            target = await self.get_invoke_target(config)
            invoke_fn = self._resolve_invoke_fn(target)

            result = await invoke_fn({"messages": seed})
            produced: List[BaseMessage] = result.get("messages", [])
            row.messages = [self._to_ep_message(m) for m in produced]
            return row

        return [asyncio.create_task(_run_one(r)) for r in rows]

    @staticmethod
    def _resolve_invoke_fn(target):
        """Pick the async invocation entry point for *target* (see class doc)."""
        graph = getattr(target, "graph", None)
        if graph is not None and hasattr(graph, "ainvoke"):
            return graph.ainvoke
        if hasattr(target, "ainvoke"):
            return target.ainvoke
        if callable(target):

            async def _call(payload):
                return await target(payload)

            return _call
        raise TypeError("Unsupported invoke target for LangGraphRolloutProcessor")

    @staticmethod
    def _to_ep_message(msg: BaseMessage) -> Message:
        """Serialize one LC message, preferring the SDK-level serializer."""
        try:
            from eval_protocol.adapters.langchain import serialize_lc_message_to_ep as _ser

            return _ser(msg)
        except Exception:
            # Minimal fallback: best-effort string content only.
            return Message(role=getattr(msg, "type", "assistant"), content=str(getattr(msg, "content", "")))

    def cleanup(self) -> None:
        """No resources to release."""
        return None
41 changes: 29 additions & 12 deletions eval_protocol/pytest/default_single_turn_rollout_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List

from litellm import acompletion
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
from typing import Dict

from eval_protocol.dataset_logger import default_logger
from eval_protocol.models import EvaluationRow, Message
Expand Down Expand Up @@ -71,17 +71,34 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:

converted_tool_calls = None
if tool_calls:
converted_tool_calls = [
ChatCompletionMessageToolCall(
id=tool_call.id,
type=tool_call.type,
function={
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
)
for tool_call in tool_calls
]
converted_tool_calls = []
for tool_call in tool_calls:
try:
converted_tool_calls.append(
{
"id": tool_call.id,
"type": tool_call.type,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
}
)
except Exception:
# best-effort: fallback to dict form
try:
converted_tool_calls.append(
{
"id": getattr(tool_call, "id", "toolcall_0"),
"type": getattr(tool_call, "type", "function"),
"function": {
"name": getattr(getattr(tool_call, "function", None), "name", "tool"),
"arguments": getattr(getattr(tool_call, "function", None), "arguments", "{}"),
},
}
)
except Exception:
pass

messages = list(row.messages) + [
Message(
Expand Down
18 changes: 18 additions & 0 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,24 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
)
else:
r.eval_metadata.status = Status.eval_finished()
# Optional debug print for assistant/tool sequence
if os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1":
try:
preview = [
{
"role": m.role,
"len": len(m.content or "") if isinstance(m.content, str) else None,
"tool_calls": len(m.tool_calls or [])
if hasattr(m, "tool_calls") and isinstance(m.tool_calls, list)
else 0,
"tool_call_id": getattr(m, "tool_call_id", None),
"name": getattr(m, "name", None),
}
for m in r.messages
]
print("[EP-Log] Row messages:", preview)
except Exception:
pass
active_logger.log(r)

# if rollout_processor is McpGymRolloutProcessor, we execute runs sequentially since McpGym does not support concurrent runs
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies = [
"dataclasses-json>=0.5.7",
"uvicorn>=0.15.0",
"python-dotenv>=0.19.0",
"openai==1.78.1",
"openai>=1.78.1",
"aiosqlite",
"aiohttp",
"mcp>=1.9.2",
Expand Down Expand Up @@ -71,7 +71,7 @@ dev = [
"types-PyYAML",
"types-docker",
"versioneer>=0.20",
"openai==1.78.1",
"openai>=1.78.1",
"pre-commit",
"e2b",
"pytest-cov",
Expand Down
6 changes: 3 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading