Skip to content

Commit c313ce6

Browse files
committed
fix per comment
1 parent 924e198 commit c313ce6

File tree

3 files changed

+7
-12
lines changed

3 files changed

+7
-12
lines changed

eval_protocol/pytest/default_langchain_rollout_processor.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@ class LangGraphRolloutProcessor(RolloutProcessor):
2121
def __init__(
2222
self,
2323
*,
24-
# Prefer factory that accepts RolloutProcessorConfig for parity with Pydantic pattern.
25-
# For backward compatibility, factories accepting a Dict[str, Any] (graph kwargs) are still supported.
26-
graph_factory: Callable[[Any], Any],
24+
# Factory must accept RolloutProcessorConfig (parity with Pydantic AI processor)
25+
graph_factory: Callable[[RolloutProcessorConfig], Any],
2726
to_input: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None,
2827
apply_result: Optional[Callable[[EvaluationRow, Any], EvaluationRow]] = None,
2928
build_graph_kwargs: Optional[Callable[[CompletionParams], Dict[str, Any]]] = None,
@@ -126,12 +125,8 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
126125
if config.completion_params:
127126
graph_config = build_kwargs(config.completion_params)
128127

129-
# (Re)build the graph for this call. Prefer passing full config to factory;
130-
# fall back to old dict-based factories if needed.
131-
try:
132-
graph_target = self._graph_factory(config) # type: ignore[arg-type]
133-
except TypeError:
134-
graph_target = self._graph_factory(graph_config or {})
128+
# (Re)build the graph for this call using the full typed config.
129+
graph_target = self._graph_factory(config)
135130

136131
# Build per-invoke config if provided; otherwise reuse graph_config for backwards compat
137132
invoke_config: Optional[Dict[str, Any]] = None

examples/langgraph/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
# Package marker for examples.langgraph
2-

tests/pytest/test_langgraph_processor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from eval_protocol.models import EvaluationRow, Message
99
from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor
10+
from eval_protocol.pytest.types import RolloutProcessorConfig
1011

1112

1213
class DummyLCMessage:
@@ -25,7 +26,7 @@ async def ainvoke(self, payload: Dict[str, Any], **_: Any):
2526

2627

2728
def _make_processor_with_defaults(out_messages: List[Any]) -> LangGraphRolloutProcessor:
28-
def graph_factory(_: Dict[str, Any]):
29+
def graph_factory(_: RolloutProcessorConfig):
2930
return DummyGraph(out_messages)
3031

3132
return LangGraphRolloutProcessor(graph_factory=graph_factory)
@@ -116,7 +117,7 @@ async def ainvoke(self, payload, **_):
116117
# Ensure our adapter-produced messages flow through
117118
return payload
118119

119-
processor = LangGraphRolloutProcessor(graph_factory=lambda _: EchoGraph())
120+
processor = LangGraphRolloutProcessor(graph_factory=lambda _config: EchoGraph())
120121

121122
# Act
122123
tasks = processor(

0 commit comments

Comments
 (0)