diff --git a/eval_protocol/pytest/default_langchain_rollout_processor.py b/eval_protocol/pytest/default_langchain_rollout_processor.py index 95ff0769..4fc24b92 100644 --- a/eval_protocol/pytest/default_langchain_rollout_processor.py +++ b/eval_protocol/pytest/default_langchain_rollout_processor.py @@ -21,12 +21,15 @@ class LangGraphRolloutProcessor(RolloutProcessor): def __init__( self, *, - graph_factory: Callable[[Dict[str, Any]], Any], + # Factory must accept RolloutProcessorConfig (parity with Pydantic AI processor) + graph_factory: Callable[[RolloutProcessorConfig], Any], to_input: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None, apply_result: Optional[Callable[[EvaluationRow, Any], EvaluationRow]] = None, build_graph_kwargs: Optional[Callable[[CompletionParams], Dict[str, Any]]] = None, input_key: str = "messages", output_key: str = "messages", + # Optional: build per-invoke RunnableConfig dict from full RolloutProcessorConfig + build_invoke_config: Optional[Callable[[RolloutProcessorConfig], Dict[str, Any]]] = None, ) -> None: # Build the graph per-call using completion_params self._graph_factory = graph_factory @@ -35,6 +38,7 @@ def __init__( self._build_graph_kwargs = build_graph_kwargs self._input_key = input_key self._output_key = output_key + self._build_invoke_config = build_invoke_config def _default_to_input(self, row: EvaluationRow) -> Dict[str, Any]: messages = row.messages or [] @@ -121,14 +125,21 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> if config.completion_params: graph_config = build_kwargs(config.completion_params) - # (Re)build the graph for this call using the graph kwargs - graph_target = self._graph_factory(graph_config or {}) + # (Re)build the graph for this call using the full typed config. + graph_target = self._graph_factory(config) + + # Build per-invoke config if provided; otherwise reuse graph_config for backwards compat + invoke_config: Optional[Dict[str, Any]] = None + if self._build_invoke_config is not None: + invoke_config = self._build_invoke_config(config) + elif graph_config is not None: + invoke_config = graph_config async def _process_row(row: EvaluationRow) -> EvaluationRow: try: payload = to_input(row) - if graph_config is not None: - result = await graph_target.ainvoke(payload, config=graph_config) + if invoke_config is not None: + result = await graph_target.ainvoke(payload, config=invoke_config) else: result = await graph_target.ainvoke(payload) row = apply_result(row, result) diff --git a/examples/langgraph/__init__.py b/examples/langgraph/__init__.py new file mode 100644 index 00000000..5b3fe18e --- /dev/null +++ b/examples/langgraph/__init__.py @@ -0,0 +1 @@ +# Package marker for examples.langgraph diff --git a/examples/langgraph/test_langgraph_rollout.py b/examples/langgraph/test_langgraph_rollout.py index 728000cb..355b32e2 100644 --- a/examples/langgraph/test_langgraph_rollout.py +++ b/examples/langgraph/test_langgraph_rollout.py @@ -3,9 +3,9 @@ from eval_protocol.models import EvaluationRow, EvaluateResult, Message from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor -from eval_protocol.pytest.types import RolloutProcessorConfig as _UnusedRolloutProcessorConfig # noqa: F401 +from eval_protocol.pytest.types import RolloutProcessorConfig -from examples.langgraph.simple_graph import build_simple_graph +from .simple_graph import build_simple_graph import os import pytest @@ -25,27 +25,15 @@ def adapter(raw_rows: List[Dict[str, Any]]) -> List[EvaluationRow]: return rows -def build_graph_kwargs(cp: Dict[str, Any]) -> Dict[str, Any]: - return { - "config": { - "model": cp.get("model"), - "temperature": cp.get("temperature", 0.0), - } - } - - -def graph_factory(graph_kwargs: Dict[str, Any]) -> Any: - cfg = graph_kwargs.get("config", {}) if isinstance(graph_kwargs, dict) else {} - model = cfg.get("model") or "accounts/fireworks/models/kimi-k2-instruct" - temperature = cfg.get("temperature", 0.0) - # Provider is fixed to fireworks for this example; can be extended via cfg if needed +def graph_factory(config: RolloutProcessorConfig) -> Any: + cp = config.completion_params or {} + model = cp.get("model") or "accounts/fireworks/models/kimi-k2-instruct" + temperature = cp.get("temperature", 0.0) + # Provider is fixed to fireworks for this example; can be extended via cp if needed return build_simple_graph(model=model, model_provider="fireworks", temperature=temperature) -processor = LangGraphRolloutProcessor( - graph_factory=graph_factory, - build_graph_kwargs=build_graph_kwargs, -) +processor = LangGraphRolloutProcessor(graph_factory=graph_factory) @pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set") diff --git a/examples/langgraph/test_reasoning_rollout.py b/examples/langgraph/test_reasoning_rollout.py index 21d4c499..3ab4e7a5 100644 --- a/examples/langgraph/test_reasoning_rollout.py +++ b/examples/langgraph/test_reasoning_rollout.py @@ -3,8 +3,9 @@ from eval_protocol.models import EvaluationRow, EvaluateResult, Message from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig -from examples.langgraph.reasoning_gpt_oss_120b_graph import build_reasoning_graph +from .reasoning_gpt_oss_120b_graph import build_reasoning_graph import os import pytest @@ -24,21 +25,11 @@ def adapter(raw_rows: List[Dict[str, Any]]) -> List[EvaluationRow]: return rows -def build_graph_kwargs(cp: Dict[str, Any]) -> Dict[str, Any]: - return { - "config": { - "model": cp.get("model", "accounts/fireworks/models/gpt-oss-120b"), - "temperature": cp.get("temperature", 0.0), - "reasoning_effort": cp.get("reasoning_effort"), - } - } - - -def graph_factory(graph_kwargs: Dict[str, Any]) -> Any: - cfg = graph_kwargs.get("config", {}) if isinstance(graph_kwargs, dict) else {} - model = cfg.get("model") or "accounts/fireworks/models/gpt-oss-120b" - temperature = cfg.get("temperature", 0.0) - reasoning_effort = cfg.get("reasoning_effort") +def graph_factory(config: RolloutProcessorConfig) -> Any: + cp = config.completion_params or {} + model = cp.get("model") or "accounts/fireworks/models/gpt-oss-120b" + temperature = cp.get("temperature", 0.0) + reasoning_effort = cp.get("reasoning_effort") return build_reasoning_graph( model=model, model_provider="fireworks", @@ -47,10 +38,7 @@ def graph_factory(graph_kwargs: Dict[str, Any]) -> Any: ) -processor = LangGraphRolloutProcessor( - graph_factory=graph_factory, - build_graph_kwargs=build_graph_kwargs, -) +processor = LangGraphRolloutProcessor(graph_factory=graph_factory) @pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set") diff --git a/tests/pytest/test_langgraph_processor.py b/tests/pytest/test_langgraph_processor.py index 702b1c1c..49cdf722 100644 --- a/tests/pytest/test_langgraph_processor.py +++ b/tests/pytest/test_langgraph_processor.py @@ -7,6 +7,7 @@ from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig class DummyLCMessage: @@ -25,7 +26,7 @@ async def ainvoke(self, payload: Dict[str, Any], **_: Any): def _make_processor_with_defaults(out_messages: List[Any]) -> LangGraphRolloutProcessor: - def graph_factory(_: Dict[str, Any]): + def graph_factory(_: RolloutProcessorConfig): return DummyGraph(out_messages) return LangGraphRolloutProcessor(graph_factory=graph_factory) @@ -116,7 +117,7 @@ async def ainvoke(self, payload, **_): # Ensure our adapter-produced messages flow through return payload - processor = LangGraphRolloutProcessor(graph_factory=lambda _: EchoGraph()) + processor = LangGraphRolloutProcessor(graph_factory=lambda _config: EchoGraph()) # Act tasks = processor(