From 924e198a6950c8915a38a7cda91f4eca1e3d0174 Mon Sep 17 00:00:00 2001 From: benjibc Date: Wed, 10 Sep 2025 23:53:19 +0000 Subject: [PATCH 1/2] change langgraph pattern --- .../default_langchain_rollout_processor.py | 26 +++++++++++++---- examples/langgraph/__init__.py | 2 ++ examples/langgraph/test_langgraph_rollout.py | 28 ++++++------------- examples/langgraph/test_reasoning_rollout.py | 28 ++++++------------- 4 files changed, 39 insertions(+), 45 deletions(-) create mode 100644 examples/langgraph/__init__.py diff --git a/eval_protocol/pytest/default_langchain_rollout_processor.py b/eval_protocol/pytest/default_langchain_rollout_processor.py index 95ff0769..2cdf61f2 100644 --- a/eval_protocol/pytest/default_langchain_rollout_processor.py +++ b/eval_protocol/pytest/default_langchain_rollout_processor.py @@ -21,12 +21,16 @@ class LangGraphRolloutProcessor(RolloutProcessor): def __init__( self, *, - graph_factory: Callable[[Dict[str, Any]], Any], + # Prefer factory that accepts RolloutProcessorConfig for parity with Pydantic pattern. + # For backward compatibility, factories accepting a Dict[str, Any] (graph kwargs) are still supported. + graph_factory: Callable[[Any], Any], to_input: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None, apply_result: Optional[Callable[[EvaluationRow, Any], EvaluationRow]] = None, build_graph_kwargs: Optional[Callable[[CompletionParams], Dict[str, Any]]] = None, input_key: str = "messages", output_key: str = "messages", + # Optional: build per-invoke RunnableConfig dict from full RolloutProcessorConfig + build_invoke_config: Optional[Callable[[RolloutProcessorConfig], Dict[str, Any]]] = None, ) -> None: # Build the graph per-call using completion_params self._graph_factory = graph_factory @@ -35,6 +39,7 @@ def __init__( self._build_graph_kwargs = build_graph_kwargs self._input_key = input_key self._output_key = output_key + self._build_invoke_config = build_invoke_config def _default_to_input(self, row: EvaluationRow) -> Dict[str, Any]: messages = row.messages or [] @@ -121,14 +126,25 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> if config.completion_params: graph_config = build_kwargs(config.completion_params) - # (Re)build the graph for this call using the graph kwargs - graph_target = self._graph_factory(graph_config or {}) + # (Re)build the graph for this call. Prefer passing full config to factory; + # fall back to old dict-based factories if needed. + try: + graph_target = self._graph_factory(config) # type: ignore[arg-type] + except TypeError: + graph_target = self._graph_factory(graph_config or {}) + + # Build per-invoke config if provided; otherwise reuse graph_config for backwards compat + invoke_config: Optional[Dict[str, Any]] = None + if self._build_invoke_config is not None: + invoke_config = self._build_invoke_config(config) + elif graph_config is not None: + invoke_config = graph_config async def _process_row(row: EvaluationRow) -> EvaluationRow: try: payload = to_input(row) - if graph_config is not None: - result = await graph_target.ainvoke(payload, config=graph_config) + if invoke_config is not None: + result = await graph_target.ainvoke(payload, config=invoke_config) else: result = await graph_target.ainvoke(payload) row = apply_result(row, result) diff --git a/examples/langgraph/__init__.py b/examples/langgraph/__init__.py new file mode 100644 index 00000000..90fec83b --- /dev/null +++ b/examples/langgraph/__init__.py @@ -0,0 +1,2 @@ +# Package marker for examples.langgraph + diff --git a/examples/langgraph/test_langgraph_rollout.py b/examples/langgraph/test_langgraph_rollout.py index 728000cb..355b32e2 100644 --- a/examples/langgraph/test_langgraph_rollout.py +++ b/examples/langgraph/test_langgraph_rollout.py @@ -3,9 +3,9 @@ from eval_protocol.models import EvaluationRow, EvaluateResult, Message from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor -from eval_protocol.pytest.types import RolloutProcessorConfig as _UnusedRolloutProcessorConfig # noqa: F401 +from eval_protocol.pytest.types import RolloutProcessorConfig -from examples.langgraph.simple_graph import build_simple_graph +from .simple_graph import build_simple_graph import os import pytest @@ -25,27 +25,15 @@ def adapter(raw_rows: List[Dict[str, Any]]) -> List[EvaluationRow]: return rows -def build_graph_kwargs(cp: Dict[str, Any]) -> Dict[str, Any]: - return { - "config": { - "model": cp.get("model"), - "temperature": cp.get("temperature", 0.0), - } - } - - -def graph_factory(graph_kwargs: Dict[str, Any]) -> Any: - cfg = graph_kwargs.get("config", {}) if isinstance(graph_kwargs, dict) else {} - model = cfg.get("model") or "accounts/fireworks/models/kimi-k2-instruct" - temperature = cfg.get("temperature", 0.0) - # Provider is fixed to fireworks for this example; can be extended via cfg if needed +def graph_factory(config: RolloutProcessorConfig) -> Any: + cp = config.completion_params or {} + model = cp.get("model") or "accounts/fireworks/models/kimi-k2-instruct" + temperature = cp.get("temperature", 0.0) + # Provider is fixed to fireworks for this example; can be extended via cp if needed return build_simple_graph(model=model, model_provider="fireworks", temperature=temperature) -processor = LangGraphRolloutProcessor( - graph_factory=graph_factory, - build_graph_kwargs=build_graph_kwargs, -) +processor = LangGraphRolloutProcessor(graph_factory=graph_factory) @pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set") diff --git a/examples/langgraph/test_reasoning_rollout.py b/examples/langgraph/test_reasoning_rollout.py index 21d4c499..3ab4e7a5 100644 --- a/examples/langgraph/test_reasoning_rollout.py +++ b/examples/langgraph/test_reasoning_rollout.py @@ -3,8 +3,9 @@ from eval_protocol.models import EvaluationRow, EvaluateResult, Message from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig -from examples.langgraph.reasoning_gpt_oss_120b_graph import build_reasoning_graph +from .reasoning_gpt_oss_120b_graph import build_reasoning_graph import os import pytest @@ -24,21 +25,11 @@ def adapter(raw_rows: List[Dict[str, Any]]) -> List[EvaluationRow]: return rows -def build_graph_kwargs(cp: Dict[str, Any]) -> Dict[str, Any]: - return { - "config": { - "model": cp.get("model", "accounts/fireworks/models/gpt-oss-120b"), - "temperature": cp.get("temperature", 0.0), - "reasoning_effort": cp.get("reasoning_effort"), - } - } - - -def graph_factory(graph_kwargs: Dict[str, Any]) -> Any: - cfg = graph_kwargs.get("config", {}) if isinstance(graph_kwargs, dict) else {} - model = cfg.get("model") or "accounts/fireworks/models/gpt-oss-120b" - temperature = cfg.get("temperature", 0.0) - reasoning_effort = cfg.get("reasoning_effort") +def graph_factory(config: RolloutProcessorConfig) -> Any: + cp = config.completion_params or {} + model = cp.get("model") or "accounts/fireworks/models/gpt-oss-120b" + temperature = cp.get("temperature", 0.0) + reasoning_effort = cp.get("reasoning_effort") return build_reasoning_graph( model=model, model_provider="fireworks", @@ -47,10 +38,7 @@ def graph_factory(graph_kwargs: Dict[str, Any]) -> Any: ) -processor = LangGraphRolloutProcessor( - graph_factory=graph_factory, - build_graph_kwargs=build_graph_kwargs, -) +processor = LangGraphRolloutProcessor(graph_factory=graph_factory) @pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set") From c313ce673546f1d400e430d155a003e2b9be6f2c Mon Sep 17 00:00:00 2001 From: benjibc Date: Sat, 13 Sep 2025 00:33:05 +0000 Subject: [PATCH 2/2] fix per comment --- .../pytest/default_langchain_rollout_processor.py | 13 ++++--------- examples/langgraph/__init__.py | 1 - tests/pytest/test_langgraph_processor.py | 5 +++-- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/eval_protocol/pytest/default_langchain_rollout_processor.py b/eval_protocol/pytest/default_langchain_rollout_processor.py index 2cdf61f2..4fc24b92 100644 --- a/eval_protocol/pytest/default_langchain_rollout_processor.py +++ b/eval_protocol/pytest/default_langchain_rollout_processor.py @@ -21,9 +21,8 @@ class LangGraphRolloutProcessor(RolloutProcessor): def __init__( self, *, - # Prefer factory that accepts RolloutProcessorConfig for parity with Pydantic pattern. - # For backward compatibility, factories accepting a Dict[str, Any] (graph kwargs) are still supported. - graph_factory: Callable[[Any], Any], + # Factory must accept RolloutProcessorConfig (parity with Pydantic AI processor) + graph_factory: Callable[[RolloutProcessorConfig], Any], to_input: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None, apply_result: Optional[Callable[[EvaluationRow, Any], EvaluationRow]] = None, build_graph_kwargs: Optional[Callable[[CompletionParams], Dict[str, Any]]] = None, @@ -126,12 +125,8 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> if config.completion_params: graph_config = build_kwargs(config.completion_params) - # (Re)build the graph for this call. Prefer passing full config to factory; - # fall back to old dict-based factories if needed. - try: - graph_target = self._graph_factory(config) # type: ignore[arg-type] - except TypeError: - graph_target = self._graph_factory(graph_config or {}) + # (Re)build the graph for this call using the full typed config. + graph_target = self._graph_factory(config) # Build per-invoke config if provided; otherwise reuse graph_config for backwards compat invoke_config: Optional[Dict[str, Any]] = None diff --git a/examples/langgraph/__init__.py b/examples/langgraph/__init__.py index 90fec83b..5b3fe18e 100644 --- a/examples/langgraph/__init__.py +++ b/examples/langgraph/__init__.py @@ -1,2 +1 @@ # Package marker for examples.langgraph - diff --git a/tests/pytest/test_langgraph_processor.py b/tests/pytest/test_langgraph_processor.py index 702b1c1c..49cdf722 100644 --- a/tests/pytest/test_langgraph_processor.py +++ b/tests/pytest/test_langgraph_processor.py @@ -7,6 +7,7 @@ from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig class DummyLCMessage: @@ -25,7 +26,7 @@ async def ainvoke(self, payload: Dict[str, Any], **_: Any): def _make_processor_with_defaults(out_messages: List[Any]) -> LangGraphRolloutProcessor: - def graph_factory(_: Dict[str, Any]): + def graph_factory(_: RolloutProcessorConfig): return DummyGraph(out_messages) return LangGraphRolloutProcessor(graph_factory=graph_factory) @@ -116,7 +117,7 @@ async def ainvoke(self, payload, **_): # Ensure our adapter-produced messages flow through return payload - processor = LangGraphRolloutProcessor(graph_factory=lambda _: EchoGraph()) + processor = LangGraphRolloutProcessor(graph_factory=lambda _config: EchoGraph()) # Act tasks = processor(