Skip to content

Commit 908d14a

Browse files
authored
change langgraph pattern (#173)
* change langgraph pattern * fix per comment
1 parent 25d6c12 commit 908d14a

File tree

5 files changed

+36
-47
lines changed

5 files changed

+36
-47
lines changed

eval_protocol/pytest/default_langchain_rollout_processor.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@ class LangGraphRolloutProcessor(RolloutProcessor):
2121
def __init__(
2222
self,
2323
*,
24-
graph_factory: Callable[[Dict[str, Any]], Any],
24+
# Factory must accept RolloutProcessorConfig (parity with Pydantic AI processor)
25+
graph_factory: Callable[[RolloutProcessorConfig], Any],
2526
to_input: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None,
2627
apply_result: Optional[Callable[[EvaluationRow, Any], EvaluationRow]] = None,
2728
build_graph_kwargs: Optional[Callable[[CompletionParams], Dict[str, Any]]] = None,
2829
input_key: str = "messages",
2930
output_key: str = "messages",
31+
# Optional: build per-invoke RunnableConfig dict from full RolloutProcessorConfig
32+
build_invoke_config: Optional[Callable[[RolloutProcessorConfig], Dict[str, Any]]] = None,
3033
) -> None:
3134
# Build the graph per-call using completion_params
3235
self._graph_factory = graph_factory
@@ -35,6 +38,7 @@ def __init__(
3538
self._build_graph_kwargs = build_graph_kwargs
3639
self._input_key = input_key
3740
self._output_key = output_key
41+
self._build_invoke_config = build_invoke_config
3842

3943
def _default_to_input(self, row: EvaluationRow) -> Dict[str, Any]:
4044
messages = row.messages or []
@@ -121,14 +125,21 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
121125
if config.completion_params:
122126
graph_config = build_kwargs(config.completion_params)
123127

124-
# (Re)build the graph for this call using the graph kwargs
125-
graph_target = self._graph_factory(graph_config or {})
128+
# (Re)build the graph for this call using the full typed config.
129+
graph_target = self._graph_factory(config)
130+
131+
# Build per-invoke config if provided; otherwise reuse graph_config for backwards compat
132+
invoke_config: Optional[Dict[str, Any]] = None
133+
if self._build_invoke_config is not None:
134+
invoke_config = self._build_invoke_config(config)
135+
elif graph_config is not None:
136+
invoke_config = graph_config
126137

127138
async def _process_row(row: EvaluationRow) -> EvaluationRow:
128139
try:
129140
payload = to_input(row)
130-
if graph_config is not None:
131-
result = await graph_target.ainvoke(payload, config=graph_config)
141+
if invoke_config is not None:
142+
result = await graph_target.ainvoke(payload, config=invoke_config)
132143
else:
133144
result = await graph_target.ainvoke(payload)
134145
row = apply_result(row, result)

examples/langgraph/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Package marker for examples.langgraph

examples/langgraph/test_langgraph_rollout.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from eval_protocol.models import EvaluationRow, EvaluateResult, Message
44
from eval_protocol.pytest import evaluation_test
55
from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor
6-
from eval_protocol.pytest.types import RolloutProcessorConfig as _UnusedRolloutProcessorConfig # noqa: F401
6+
from eval_protocol.pytest.types import RolloutProcessorConfig
77

8-
from examples.langgraph.simple_graph import build_simple_graph
8+
from .simple_graph import build_simple_graph
99
import os
1010
import pytest
1111

@@ -25,27 +25,15 @@ def adapter(raw_rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
2525
return rows
2626

2727

28-
def build_graph_kwargs(cp: Dict[str, Any]) -> Dict[str, Any]:
29-
return {
30-
"config": {
31-
"model": cp.get("model"),
32-
"temperature": cp.get("temperature", 0.0),
33-
}
34-
}
35-
36-
37-
def graph_factory(graph_kwargs: Dict[str, Any]) -> Any:
38-
cfg = graph_kwargs.get("config", {}) if isinstance(graph_kwargs, dict) else {}
39-
model = cfg.get("model") or "accounts/fireworks/models/kimi-k2-instruct"
40-
temperature = cfg.get("temperature", 0.0)
41-
# Provider is fixed to fireworks for this example; can be extended via cfg if needed
28+
def graph_factory(config: RolloutProcessorConfig) -> Any:
29+
cp = config.completion_params or {}
30+
model = cp.get("model") or "accounts/fireworks/models/kimi-k2-instruct"
31+
temperature = cp.get("temperature", 0.0)
32+
# Provider is fixed to fireworks for this example; can be extended via cp if needed
4233
return build_simple_graph(model=model, model_provider="fireworks", temperature=temperature)
4334

4435

45-
processor = LangGraphRolloutProcessor(
46-
graph_factory=graph_factory,
47-
build_graph_kwargs=build_graph_kwargs,
48-
)
36+
processor = LangGraphRolloutProcessor(graph_factory=graph_factory)
4937

5038

5139
@pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set")

examples/langgraph/test_reasoning_rollout.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from eval_protocol.models import EvaluationRow, EvaluateResult, Message
44
from eval_protocol.pytest import evaluation_test
55
from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor
6+
from eval_protocol.pytest.types import RolloutProcessorConfig
67

7-
from examples.langgraph.reasoning_gpt_oss_120b_graph import build_reasoning_graph
8+
from .reasoning_gpt_oss_120b_graph import build_reasoning_graph
89
import os
910
import pytest
1011

@@ -24,21 +25,11 @@ def adapter(raw_rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
2425
return rows
2526

2627

27-
def build_graph_kwargs(cp: Dict[str, Any]) -> Dict[str, Any]:
28-
return {
29-
"config": {
30-
"model": cp.get("model", "accounts/fireworks/models/gpt-oss-120b"),
31-
"temperature": cp.get("temperature", 0.0),
32-
"reasoning_effort": cp.get("reasoning_effort"),
33-
}
34-
}
35-
36-
37-
def graph_factory(graph_kwargs: Dict[str, Any]) -> Any:
38-
cfg = graph_kwargs.get("config", {}) if isinstance(graph_kwargs, dict) else {}
39-
model = cfg.get("model") or "accounts/fireworks/models/gpt-oss-120b"
40-
temperature = cfg.get("temperature", 0.0)
41-
reasoning_effort = cfg.get("reasoning_effort")
28+
def graph_factory(config: RolloutProcessorConfig) -> Any:
29+
cp = config.completion_params or {}
30+
model = cp.get("model") or "accounts/fireworks/models/gpt-oss-120b"
31+
temperature = cp.get("temperature", 0.0)
32+
reasoning_effort = cp.get("reasoning_effort")
4233
return build_reasoning_graph(
4334
model=model,
4435
model_provider="fireworks",
@@ -47,10 +38,7 @@ def graph_factory(graph_kwargs: Dict[str, Any]) -> Any:
4738
)
4839

4940

50-
processor = LangGraphRolloutProcessor(
51-
graph_factory=graph_factory,
52-
build_graph_kwargs=build_graph_kwargs,
53-
)
41+
processor = LangGraphRolloutProcessor(graph_factory=graph_factory)
5442

5543

5644
@pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set")

tests/pytest/test_langgraph_processor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from eval_protocol.models import EvaluationRow, Message
99
from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor
10+
from eval_protocol.pytest.types import RolloutProcessorConfig
1011

1112

1213
class DummyLCMessage:
@@ -25,7 +26,7 @@ async def ainvoke(self, payload: Dict[str, Any], **_: Any):
2526

2627

2728
def _make_processor_with_defaults(out_messages: List[Any]) -> LangGraphRolloutProcessor:
28-
def graph_factory(_: Dict[str, Any]):
29+
def graph_factory(_: RolloutProcessorConfig):
2930
return DummyGraph(out_messages)
3031

3132
return LangGraphRolloutProcessor(graph_factory=graph_factory)
@@ -116,7 +117,7 @@ async def ainvoke(self, payload, **_):
116117
# Ensure our adapter-produced messages flow through
117118
return payload
118119

119-
processor = LangGraphRolloutProcessor(graph_factory=lambda _: EchoGraph())
120+
processor = LangGraphRolloutProcessor(graph_factory=lambda _config: EchoGraph())
120121

121122
# Act
122123
tasks = processor(

0 commit comments

Comments
 (0)