Skip to content

Commit af8ac5c

Browse files
committed
simplify further
1 parent 854cb5c commit af8ac5c

File tree

6 files changed

+126
-238
lines changed

6 files changed

+126
-238
lines changed
Lines changed: 123 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,110 +1,144 @@
11
import asyncio
2-
import time
3-
from typing import List, Any, cast
2+
from typing import Any, Callable, Dict, List, Optional
43

5-
try:
6-
from langchain_core.messages import BaseMessage as LCBaseMessage, HumanMessage # type: ignore
7-
except ImportError: # pragma: no cover - optional dependency path
8-
# Minimal fallbacks to satisfy typing when langchain is not present
9-
class LCBaseMessage: # type: ignore
10-
content: str
11-
type: str
12-
13-
def __init__(self, content: str = "", msg_type: str = "assistant"):
14-
self.content = content
15-
self.type = msg_type
16-
17-
class HumanMessage(LCBaseMessage): # type: ignore
18-
def __init__(self, content: str):
19-
super().__init__(content=content, msg_type="human")
20-
21-
22-
from eval_protocol.models import EvaluationRow, Message
4+
from eval_protocol.models import EvaluationRow, Status, Message
235
from eval_protocol.pytest.rollout_processor import RolloutProcessor
24-
from eval_protocol.pytest.types import RolloutProcessorConfig
6+
from eval_protocol.pytest.types import CompletionParams, RolloutProcessorConfig
257

268

279
class LangGraphRolloutProcessor(RolloutProcessor):
    """
    Generic rollout processor for LangGraph graphs.

    Configure with:
    - to_input(row): build the input payload for graph.ainvoke (default: {"messages": row.messages})
    - apply_result(row, result): write graph outputs back onto the row (default: row.messages = result["messages"])
    - build_graph_kwargs(cp): map completion_params to graph kwargs (default: {})

    Compatible with eval_protocol.pytest.evaluation_test.
    """

    def __init__(
        self,
        *,
        graph_factory: Callable[[Dict[str, Any]], Any],
        to_input: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None,
        apply_result: Optional[Callable[[EvaluationRow, Any], EvaluationRow]] = None,
        build_graph_kwargs: Optional[Callable[[CompletionParams], Dict[str, Any]]] = None,
        input_key: str = "messages",
        output_key: str = "messages",
    ) -> None:
        """Store the factory plus optional hooks; the graph itself is built per-call.

        Args:
            graph_factory: called with the kwargs produced by ``build_graph_kwargs``
                (or ``{}``) and must return an object exposing ``ainvoke``.
            to_input: optional override for building the graph input payload from a row.
            apply_result: optional override for writing the graph result back onto a row.
            build_graph_kwargs: optional mapping from completion params to graph kwargs.
            input_key: key under which input messages are placed in the payload.
            output_key: key from which output messages are read in the result dict.
        """
        # Graph is (re)built on every __call__ so completion_params can vary per run.
        self._graph_factory = graph_factory
        self._to_input = to_input
        self._apply_result = apply_result
        self._build_graph_kwargs = build_graph_kwargs
        self._input_key = input_key
        self._output_key = output_key

    def _default_to_input(self, row: EvaluationRow) -> Dict[str, Any]:
        """Convert the row's EP messages to LangChain messages under ``input_key``."""
        ep_messages = row.messages or []
        # Local import: the langchain adapter is an optional dependency path.
        from eval_protocol.adapters.langchain import serialize_ep_messages_to_lc as _to_lc

        return {self._input_key: _to_lc(ep_messages)}

    def _default_apply_result(self, row: EvaluationRow, result: Any) -> EvaluationRow:
        """Write graph output messages back onto ``row.messages``.

        Expects ``result`` to be a dict holding a message list under ``output_key``;
        anything else leaves the row untouched. Messages are coerced to EP ``Message``
        objects, preserving items that already are EP messages.
        """
        candidate = result.get(self._output_key) if isinstance(result, dict) else None
        if candidate is None:
            return row

        # Fast path: the graph already produced EP messages.
        if isinstance(candidate, list) and all(isinstance(m, Message) for m in candidate):
            row.messages = candidate
            return row

        # Preferred path: convert LangChain messages via the adapter.
        try:
            from langchain_core.messages import BaseMessage as _LCBase
            from eval_protocol.adapters.langchain import serialize_lc_message_to_ep as _to_ep

            if isinstance(candidate, list) and any(isinstance(m, _LCBase) for m in candidate):
                converted: List[Message] = []
                for item in candidate:
                    if isinstance(item, Message):
                        converted.append(item)
                    elif isinstance(item, _LCBase):
                        converted.append(_to_ep(item))
                    elif isinstance(item, dict):
                        converted.append(
                            Message(role=item.get("role") or "assistant", content=item.get("content"))
                        )
                    else:
                        # Best-effort for LC-like objects without importing LC types
                        role_like = getattr(item, "type", None)
                        content_like = getattr(item, "content", None)
                        if content_like is None:
                            converted.append(Message(role="assistant", content=str(item)))
                            continue
                        # Map LC role names onto the EP role vocabulary.
                        role_value = "assistant"
                        if isinstance(role_like, str):
                            lowered = role_like.lower()
                            if lowered in ("human", "user"):
                                role_value = "user"
                            elif lowered in ("ai", "assistant"):
                                role_value = "assistant"
                            elif lowered in ("system",):
                                role_value = "system"
                        converted.append(Message(role=role_value, content=str(content_like)))
                row.messages = converted
                return row
        except ImportError:
            # If LC is not available, fall back to best-effort below
            pass

        # Generic best-effort fallback: stringify to assistant messages
        if isinstance(candidate, list):
            row.messages = [Message(role="assistant", content=str(m)) for m in candidate]
        else:
            row.messages = [Message(role="assistant", content=str(candidate))]
        return row

    def _default_build_graph_kwargs(self, _: CompletionParams) -> Dict[str, Any]:
        """Default mapping from completion params to graph kwargs: nothing.

        Kept generic on purpose; callers override to map params to their graph's
        expected kwargs.
        """
        return {}

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Launch one asyncio task per row; each task invokes the graph and returns the row.

        Errors during a rollout are captured onto ``row.rollout_status`` rather than
        propagated, so one failing row does not abort the batch.
        """
        resolve_input = self._to_input or self._default_to_input
        resolve_result = self._apply_result or self._default_apply_result
        resolve_kwargs = self._build_graph_kwargs or self._default_build_graph_kwargs

        graph_config: Optional[Dict[str, Any]] = None
        if config.completion_params:
            graph_config = resolve_kwargs(config.completion_params)

        # (Re)build the graph for this call using the graph kwargs
        graph_target = self._graph_factory(graph_config or {})

        async def _process_row(row: EvaluationRow) -> EvaluationRow:
            try:
                payload = resolve_input(row)
                if graph_config is not None:
                    result = await graph_target.ainvoke(payload, config=graph_config)
                else:
                    result = await graph_target.ainvoke(payload)
                row = resolve_result(row, result)
                row.rollout_status = Status.rollout_finished()
                return row
            except (RuntimeError, ValueError, TypeError, KeyError, AttributeError, ImportError) as e:  # noqa: BLE001
                row.rollout_status = Status.rollout_error(str(e))
                return row

        return [asyncio.create_task(_process_row(r)) for r in rows]

    def cleanup(self) -> None:
        """No resources are held between calls; nothing to release."""
        return None

eval_protocol/pytest/langgraph_processor.py

Lines changed: 0 additions & 144 deletions
This file was deleted.
Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
11
{"name":"p1","prompt":"Say hello in one sentence","gt":"hello"}
2-
{"name":"p2","prompt":"Introduce yourself briefly","gt":"intro"}
3-
{"name":"p3","prompt":"Respond with a fun fact about space","gt":"space"}

examples/langgraph/test_langgraph_rollout.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from eval_protocol.models import EvaluationRow, EvaluateResult, Message
44
from eval_protocol.pytest import evaluation_test
5-
from eval_protocol.pytest.langgraph_processor import LangGraphRolloutProcessor
5+
from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor
66
from eval_protocol.pytest.types import RolloutProcessorConfig as _UnusedRolloutProcessorConfig # noqa: F401
77

88
from examples.langgraph.simple_graph import build_simple_graph

0 commit comments

Comments
 (0)