-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathdefault_langchain_rollout_processor.py
More file actions
89 lines (68 loc) · 3.39 KB
/
default_langchain_rollout_processor.py
File metadata and controls
89 lines (68 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import asyncio
from typing import List
try:
from langchain_core.messages import BaseMessage
except Exception: # pragma: no cover - optional dependency path
# Minimal fallback base type to satisfy typing when langchain is not present
class BaseMessage: # type: ignore
pass
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
class LangGraphRolloutProcessor(RolloutProcessor):
    """Generic rollout processor for LangChain agents.

    Accepts an async factory that returns a target to invoke. The target can be:
    - An object with `.graph.ainvoke(payload)` (e.g., LangGraph compiled graph)
    - An object with `.ainvoke(payload)`
    - A callable that accepts `payload` and returns the result dict
    """

    def __init__(self, get_invoke_target):
        # Async factory: `await get_invoke_target(config)` must yield one of the
        # invoke targets described in the class docstring.
        self.get_invoke_target = get_invoke_target

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task]:
        """Schedule one asyncio task per row and return the task list.

        NOTE: must be invoked from within a running event loop —
        `asyncio.create_task` raises RuntimeError otherwise.
        """
        tasks: List[asyncio.Task] = []

        async def _process_row(row: EvaluationRow) -> EvaluationRow:
            # Build LC messages from the EP row. The import is deferred so the
            # module stays usable when langchain_core is not installed.
            try:
                from langchain_core.messages import HumanMessage
            except ImportError:  # narrow guard: only a missing dependency is expected
                # Fallback minimal message if langchain_core is unavailable
                class HumanMessage:  # type: ignore
                    def __init__(self, content: str):
                        self.content = content

            lm_messages: List[BaseMessage] = []
            if row.messages:
                # Only the most recent user turn is forwarded to the agent;
                # earlier turns in the row are intentionally dropped.
                last_user = [m for m in row.messages if m.role == "user"]
                if last_user:
                    lm_messages.append(HumanMessage(content=last_user[-1].content or ""))
            if not lm_messages:
                lm_messages = [HumanMessage(content="")]  # minimal non-empty payload

            target = await self.get_invoke_target(config)

            # Resolve the appropriate async invoke function. Order matters:
            # a compiled-graph holder wins over a directly-invokable object,
            # which wins over a plain async callable.
            if hasattr(target, "graph") and hasattr(target.graph, "ainvoke"):
                invoke_fn = target.graph.ainvoke
            elif hasattr(target, "ainvoke"):
                invoke_fn = target.ainvoke
            elif callable(target):

                async def _invoke_wrapper(payload):
                    return await target(payload)

                invoke_fn = _invoke_wrapper
            else:
                raise TypeError("Unsupported invoke target for LangGraphRolloutProcessor")

            result = await invoke_fn({"messages": lm_messages})
            result_messages: List[BaseMessage] = result.get("messages", [])

            # Resolve the SDK serializer ONCE per row (hoisted out of the
            # per-message loop; the original re-imported it for every message).
            try:
                from eval_protocol.adapters.langchain import serialize_lc_message_to_ep as _ser
            except ImportError:
                _ser = None

            def _serialize_message(msg: BaseMessage) -> Message:
                # Prefer the SDK-level serializer; best-effort fall back to a
                # plain string content message if it is missing or raises.
                if _ser is not None:
                    try:
                        return _ser(msg)
                    except Exception:
                        pass
                content = getattr(msg, "content", "")
                # NOTE(review): LC message `type` is e.g. "human"/"ai" — confirm
                # those are accepted EP roles before relying on this fallback.
                return Message(role=getattr(msg, "type", "assistant"), content=str(content))

            # Replace the row's messages with the serialized agent transcript.
            row.messages = [_serialize_message(m) for m in result_messages]
            return row

        for r in rows:
            tasks.append(asyncio.create_task(_process_row(r)))
        return tasks

    def cleanup(self) -> None:
        """No resources to release; present to satisfy the RolloutProcessor API."""
        return None