Skip to content

Commit 924e198

Browse files
committed
change langgraph pattern
1 parent 8101180 commit 924e198

File tree

4 files changed

+39
-45
lines changed

4 files changed

+39
-45
lines changed

eval_protocol/pytest/default_langchain_rollout_processor.py

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -21,12 +21,16 @@ class LangGraphRolloutProcessor(RolloutProcessor):
2121
def __init__(
2222
self,
2323
*,
24-
graph_factory: Callable[[Dict[str, Any]], Any],
24+
# Prefer factory that accepts RolloutProcessorConfig for parity with Pydantic pattern.
25+
# For backward compatibility, factories accepting a Dict[str, Any] (graph kwargs) are still supported.
26+
graph_factory: Callable[[Any], Any],
2527
to_input: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None,
2628
apply_result: Optional[Callable[[EvaluationRow, Any], EvaluationRow]] = None,
2729
build_graph_kwargs: Optional[Callable[[CompletionParams], Dict[str, Any]]] = None,
2830
input_key: str = "messages",
2931
output_key: str = "messages",
32+
# Optional: build per-invoke RunnableConfig dict from full RolloutProcessorConfig
33+
build_invoke_config: Optional[Callable[[RolloutProcessorConfig], Dict[str, Any]]] = None,
3034
) -> None:
3135
# Build the graph per-call using completion_params
3236
self._graph_factory = graph_factory
@@ -35,6 +39,7 @@ def __init__(
3539
self._build_graph_kwargs = build_graph_kwargs
3640
self._input_key = input_key
3741
self._output_key = output_key
42+
self._build_invoke_config = build_invoke_config
3843

3944
def _default_to_input(self, row: EvaluationRow) -> Dict[str, Any]:
4045
messages = row.messages or []
@@ -121,14 +126,25 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
121126
if config.completion_params:
122127
graph_config = build_kwargs(config.completion_params)
123128

124-
# (Re)build the graph for this call using the graph kwargs
125-
graph_target = self._graph_factory(graph_config or {})
129+
# (Re)build the graph for this call. Prefer passing full config to factory;
130+
# fall back to old dict-based factories if needed.
131+
try:
132+
graph_target = self._graph_factory(config) # type: ignore[arg-type]
133+
except TypeError:
134+
graph_target = self._graph_factory(graph_config or {})
135+
136+
# Build per-invoke config if provided; otherwise reuse graph_config for backwards compat
137+
invoke_config: Optional[Dict[str, Any]] = None
138+
if self._build_invoke_config is not None:
139+
invoke_config = self._build_invoke_config(config)
140+
elif graph_config is not None:
141+
invoke_config = graph_config
126142

127143
async def _process_row(row: EvaluationRow) -> EvaluationRow:
128144
try:
129145
payload = to_input(row)
130-
if graph_config is not None:
131-
result = await graph_target.ainvoke(payload, config=graph_config)
146+
if invoke_config is not None:
147+
result = await graph_target.ainvoke(payload, config=invoke_config)
132148
else:
133149
result = await graph_target.ainvoke(payload)
134150
row = apply_result(row, result)

examples/langgraph/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,2 @@
1+
# Package marker for examples.langgraph
2+

examples/langgraph/test_langgraph_rollout.py

Lines changed: 8 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -3,9 +3,9 @@
33
from eval_protocol.models import EvaluationRow, EvaluateResult, Message
44
from eval_protocol.pytest import evaluation_test
55
from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor
6-
from eval_protocol.pytest.types import RolloutProcessorConfig as _UnusedRolloutProcessorConfig # noqa: F401
6+
from eval_protocol.pytest.types import RolloutProcessorConfig
77

8-
from examples.langgraph.simple_graph import build_simple_graph
8+
from .simple_graph import build_simple_graph
99
import os
1010
import pytest
1111

@@ -25,27 +25,15 @@ def adapter(raw_rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
2525
return rows
2626

2727

28-
def build_graph_kwargs(cp: Dict[str, Any]) -> Dict[str, Any]:
29-
return {
30-
"config": {
31-
"model": cp.get("model"),
32-
"temperature": cp.get("temperature", 0.0),
33-
}
34-
}
35-
36-
37-
def graph_factory(graph_kwargs: Dict[str, Any]) -> Any:
38-
cfg = graph_kwargs.get("config", {}) if isinstance(graph_kwargs, dict) else {}
39-
model = cfg.get("model") or "accounts/fireworks/models/kimi-k2-instruct"
40-
temperature = cfg.get("temperature", 0.0)
41-
# Provider is fixed to fireworks for this example; can be extended via cfg if needed
28+
def graph_factory(config: RolloutProcessorConfig) -> Any:
29+
cp = config.completion_params or {}
30+
model = cp.get("model") or "accounts/fireworks/models/kimi-k2-instruct"
31+
temperature = cp.get("temperature", 0.0)
32+
# Provider is fixed to fireworks for this example; can be extended via cp if needed
4233
return build_simple_graph(model=model, model_provider="fireworks", temperature=temperature)
4334

4435

45-
processor = LangGraphRolloutProcessor(
46-
graph_factory=graph_factory,
47-
build_graph_kwargs=build_graph_kwargs,
48-
)
36+
processor = LangGraphRolloutProcessor(graph_factory=graph_factory)
4937

5038

5139
@pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set")

examples/langgraph/test_reasoning_rollout.py

Lines changed: 8 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,9 @@
33
from eval_protocol.models import EvaluationRow, EvaluateResult, Message
44
from eval_protocol.pytest import evaluation_test
55
from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor
6+
from eval_protocol.pytest.types import RolloutProcessorConfig
67

7-
from examples.langgraph.reasoning_gpt_oss_120b_graph import build_reasoning_graph
8+
from .reasoning_gpt_oss_120b_graph import build_reasoning_graph
89
import os
910
import pytest
1011

@@ -24,21 +25,11 @@ def adapter(raw_rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
2425
return rows
2526

2627

27-
def build_graph_kwargs(cp: Dict[str, Any]) -> Dict[str, Any]:
28-
return {
29-
"config": {
30-
"model": cp.get("model", "accounts/fireworks/models/gpt-oss-120b"),
31-
"temperature": cp.get("temperature", 0.0),
32-
"reasoning_effort": cp.get("reasoning_effort"),
33-
}
34-
}
35-
36-
37-
def graph_factory(graph_kwargs: Dict[str, Any]) -> Any:
38-
cfg = graph_kwargs.get("config", {}) if isinstance(graph_kwargs, dict) else {}
39-
model = cfg.get("model") or "accounts/fireworks/models/gpt-oss-120b"
40-
temperature = cfg.get("temperature", 0.0)
41-
reasoning_effort = cfg.get("reasoning_effort")
28+
def graph_factory(config: RolloutProcessorConfig) -> Any:
29+
cp = config.completion_params or {}
30+
model = cp.get("model") or "accounts/fireworks/models/gpt-oss-120b"
31+
temperature = cp.get("temperature", 0.0)
32+
reasoning_effort = cp.get("reasoning_effort")
4233
return build_reasoning_graph(
4334
model=model,
4435
model_provider="fireworks",
@@ -47,10 +38,7 @@ def graph_factory(graph_kwargs: Dict[str, Any]) -> Any:
4738
)
4839

4940

50-
processor = LangGraphRolloutProcessor(
51-
graph_factory=graph_factory,
52-
build_graph_kwargs=build_graph_kwargs,
53-
)
41+
processor = LangGraphRolloutProcessor(graph_factory=graph_factory)
5442

5543

5644
@pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set")

0 commit comments

Comments
 (0)