|
1 | 1 | import asyncio |
2 | | -import time |
3 | | -from typing import List, Any, cast |
| 2 | +from typing import Any, Callable, Dict, List, Optional |
4 | 3 |
|
5 | | -try: |
6 | | - from langchain_core.messages import BaseMessage as LCBaseMessage, HumanMessage # type: ignore |
7 | | -except ImportError: # pragma: no cover - optional dependency path |
8 | | - # Minimal fallbacks to satisfy typing when langchain is not present |
9 | | - class LCBaseMessage: # type: ignore |
10 | | - content: str |
11 | | - type: str |
12 | | - |
13 | | - def __init__(self, content: str = "", msg_type: str = "assistant"): |
14 | | - self.content = content |
15 | | - self.type = msg_type |
16 | | - |
17 | | - class HumanMessage(LCBaseMessage): # type: ignore |
18 | | - def __init__(self, content: str): |
19 | | - super().__init__(content=content, msg_type="human") |
20 | | - |
21 | | - |
22 | | -from eval_protocol.models import EvaluationRow, Message |
| 4 | +from eval_protocol.models import EvaluationRow, Status, Message |
23 | 5 | from eval_protocol.pytest.rollout_processor import RolloutProcessor |
24 | | -from eval_protocol.pytest.types import RolloutProcessorConfig |
| 6 | +from eval_protocol.pytest.types import CompletionParams, RolloutProcessorConfig |
25 | 7 |
|
26 | 8 |
|
class LangGraphRolloutProcessor(RolloutProcessor):
    """Generic rollout processor for LangGraph graphs.

    Configure with:
    - to_input(row): build the input payload for graph.ainvoke
      (default: {"messages": row.messages converted to LangChain messages})
    - apply_result(row, result): write graph outputs back onto the row
      (default: row.messages = result["messages"], coerced to EP Messages)
    - build_graph_kwargs(cp): map completion_params to graph kwargs (default: {})

    Compatible with eval_protocol.pytest.evaluation_test.
    """

    def __init__(
        self,
        *,
        graph_factory: Callable[[Dict[str, Any]], Any],
        to_input: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None,
        apply_result: Optional[Callable[[EvaluationRow, Any], EvaluationRow]] = None,
        build_graph_kwargs: Optional[Callable[[CompletionParams], Dict[str, Any]]] = None,
        input_key: str = "messages",
        output_key: str = "messages",
    ) -> None:
        """Store configuration; the graph itself is (re)built on each __call__.

        Args:
            graph_factory: Called with the kwargs produced by
                ``build_graph_kwargs`` and must return an object exposing
                ``ainvoke(payload, ...)``.
            to_input: Optional override for building the graph input payload.
            apply_result: Optional override for writing graph results onto the row.
            build_graph_kwargs: Optional mapping from CompletionParams to the
                kwargs passed to ``graph_factory``.
            input_key: Key under which messages are placed in the default input payload.
            output_key: Key from which messages are read by the default result handler.
        """
        self._graph_factory = graph_factory
        self._to_input = to_input
        self._apply_result = apply_result
        self._build_graph_kwargs = build_graph_kwargs
        self._input_key = input_key
        self._output_key = output_key

    def _default_to_input(self, row: EvaluationRow) -> Dict[str, Any]:
        """Convert the row's EP messages to LangChain messages under input_key."""
        # Imported lazily so the processor can be constructed without langchain installed.
        from eval_protocol.adapters.langchain import serialize_ep_messages_to_lc as _to_lc

        return {self._input_key: _to_lc(row.messages or [])}

    @staticmethod
    def _coerce_message(m: Any, lc_base: Any, lc_to_ep: Any) -> Message:
        """Best-effort conversion of one graph output item to an EP Message.

        ``lc_base``/``lc_to_ep`` are the LangChain BaseMessage type and the
        LC-to-EP serializer, or None when langchain is not importable.
        """
        if isinstance(m, Message):
            return m
        if lc_base is not None and isinstance(m, lc_base):
            return lc_to_ep(m)
        if isinstance(m, dict):
            return Message(role=m.get("role") or "assistant", content=m.get("content"))
        # Duck-typed LC-like objects: map their `.type` onto an EP role.
        role_like = getattr(m, "type", None)
        content_like = getattr(m, "content", None)
        if content_like is None:
            # Nothing message-like about it; stringify as an assistant message.
            return Message(role="assistant", content=str(m))
        role_value = "assistant"
        if isinstance(role_like, str):
            rl = role_like.lower()
            if rl in ("human", "user"):
                role_value = "user"
            elif rl in ("ai", "assistant"):
                role_value = "assistant"
            elif rl == "system":
                role_value = "system"
        return Message(role=role_value, content=str(content_like))

    def _default_apply_result(self, row: EvaluationRow, result: Any) -> EvaluationRow:
        """Write ``result[output_key]`` back onto ``row.messages``.

        Accepts EP Messages, LangChain messages, role/content dicts, or
        arbitrary objects (stringified as assistant messages) — previously the
        dict/duck-typed conversion ran only when the list contained at least
        one LangChain message, silently stringifying dict payloads otherwise.
        Leaves the row untouched when the result has no usable messages.
        """
        maybe_msgs = result.get(self._output_key) if isinstance(result, dict) else None
        if maybe_msgs is None:
            return row

        # Fast path: already a list of EP messages — assign directly.
        if isinstance(maybe_msgs, list) and all(isinstance(m, Message) for m in maybe_msgs):
            row.messages = maybe_msgs
            return row

        # LangChain is optional; per-item coercion still works without it.
        try:
            from langchain_core.messages import BaseMessage as lc_base
            from eval_protocol.adapters.langchain import serialize_lc_message_to_ep as lc_to_ep
        except ImportError:
            lc_base = lc_to_ep = None

        if isinstance(maybe_msgs, list):
            row.messages = [self._coerce_message(m, lc_base, lc_to_ep) for m in maybe_msgs]
        else:
            row.messages = [Message(role="assistant", content=str(maybe_msgs))]
        return row

    def _default_build_graph_kwargs(self, _: CompletionParams) -> Dict[str, Any]:
        """Default: no graph kwargs. Override to map completion_params to your graph."""
        return {}

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Launch one asyncio task per row and return the tasks immediately.

        Each task builds the input payload, invokes the graph, applies the
        result to its row, and records rollout status (finished or error) on
        the row rather than letting the task raise.
        """
        to_input = self._to_input or self._default_to_input
        apply_result = self._apply_result or self._default_apply_result
        build_kwargs = self._build_graph_kwargs or self._default_build_graph_kwargs

        graph_config: Optional[Dict[str, Any]] = None
        if config.completion_params:
            graph_config = build_kwargs(config.completion_params)

        # (Re)build the graph for this call using the graph kwargs.
        # NOTE(review): graph_config doubles as the factory kwargs AND the
        # `config=` argument to ainvoke below; LangGraph expects a
        # RunnableConfig for the latter — confirm build_graph_kwargs
        # implementations produce something valid for both uses.
        graph_target = self._graph_factory(graph_config or {})

        async def _process_row(row: EvaluationRow) -> EvaluationRow:
            try:
                payload = to_input(row)
                if graph_config is not None:
                    result = await graph_target.ainvoke(payload, config=graph_config)
                else:
                    result = await graph_target.ainvoke(payload)
                row = apply_result(row, result)
                row.rollout_status = Status.rollout_finished()
            except (RuntimeError, ValueError, TypeError, KeyError, AttributeError, ImportError) as e:
                # Record the failure on the row instead of failing the task.
                row.rollout_status = Status.rollout_error(str(e))
            return row

        return [asyncio.create_task(_process_row(r)) for r in rows]

    def cleanup(self) -> None:
        """No-op: this processor holds no resources needing teardown."""
        return None
0 commit comments