Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions eval_protocol/pytest/default_langchain_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ class LangGraphRolloutProcessor(RolloutProcessor):
def __init__(
self,
*,
graph_factory: Callable[[Dict[str, Any]], Any],
# Factory must accept RolloutProcessorConfig (parity with Pydantic AI processor)
graph_factory: Callable[[RolloutProcessorConfig], Any],
to_input: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None,
apply_result: Optional[Callable[[EvaluationRow, Any], EvaluationRow]] = None,
build_graph_kwargs: Optional[Callable[[CompletionParams], Dict[str, Any]]] = None,
input_key: str = "messages",
output_key: str = "messages",
# Optional: build per-invoke RunnableConfig dict from full RolloutProcessorConfig
build_invoke_config: Optional[Callable[[RolloutProcessorConfig], Dict[str, Any]]] = None,
) -> None:
# Build the graph per-call using completion_params
self._graph_factory = graph_factory
Expand All @@ -35,6 +38,7 @@ def __init__(
self._build_graph_kwargs = build_graph_kwargs
self._input_key = input_key
self._output_key = output_key
self._build_invoke_config = build_invoke_config

def _default_to_input(self, row: EvaluationRow) -> Dict[str, Any]:
messages = row.messages or []
Expand Down Expand Up @@ -121,14 +125,21 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
if config.completion_params:
graph_config = build_kwargs(config.completion_params)

# (Re)build the graph for this call using the graph kwargs
graph_target = self._graph_factory(graph_config or {})
# (Re)build the graph for this call using the full typed config.
graph_target = self._graph_factory(config)

# Build per-invoke config if provided; otherwise reuse graph_config for backwards compat
invoke_config: Optional[Dict[str, Any]] = None
if self._build_invoke_config is not None:
invoke_config = self._build_invoke_config(config)
elif graph_config is not None:
invoke_config = graph_config

async def _process_row(row: EvaluationRow) -> EvaluationRow:
try:
payload = to_input(row)
if graph_config is not None:
result = await graph_target.ainvoke(payload, config=graph_config)
if invoke_config is not None:
result = await graph_target.ainvoke(payload, config=invoke_config)
else:
result = await graph_target.ainvoke(payload)
row = apply_result(row, result)
Expand Down
1 change: 1 addition & 0 deletions examples/langgraph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Package marker for examples.langgraph
28 changes: 8 additions & 20 deletions examples/langgraph/test_langgraph_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from eval_protocol.models import EvaluationRow, EvaluateResult, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig as _UnusedRolloutProcessorConfig # noqa: F401
from eval_protocol.pytest.types import RolloutProcessorConfig

from examples.langgraph.simple_graph import build_simple_graph
from .simple_graph import build_simple_graph
import os
import pytest

Expand All @@ -25,27 +25,15 @@ def adapter(raw_rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
return rows


def build_graph_kwargs(cp: Dict[str, Any]) -> Dict[str, Any]:
    """Translate completion params into the nested graph-kwargs mapping.

    Pulls ``model`` (no fallback) and ``temperature`` (default ``0.0``) from
    *cp* and wraps them under a single ``"config"`` key.
    """
    model = cp.get("model")
    temperature = cp.get("temperature", 0.0)
    return {"config": {"model": model, "temperature": temperature}}


def graph_factory(graph_kwargs: Dict[str, Any]) -> Any:
cfg = graph_kwargs.get("config", {}) if isinstance(graph_kwargs, dict) else {}
model = cfg.get("model") or "accounts/fireworks/models/kimi-k2-instruct"
temperature = cfg.get("temperature", 0.0)
# Provider is fixed to fireworks for this example; can be extended via cfg if needed
def graph_factory(config: RolloutProcessorConfig) -> Any:
    """Build the simple example graph from the typed rollout-processor config.

    Reads model and temperature out of ``config.completion_params`` (falling
    back to the example's default Fireworks model and ``0.0``).
    """
    params = config.completion_params or {}
    chosen_model = params.get("model") or "accounts/fireworks/models/kimi-k2-instruct"
    chosen_temperature = params.get("temperature", 0.0)
    # Provider is fixed to fireworks for this example; can be extended via params if needed
    return build_simple_graph(
        model=chosen_model,
        model_provider="fireworks",
        temperature=chosen_temperature,
    )


processor = LangGraphRolloutProcessor(
graph_factory=graph_factory,
build_graph_kwargs=build_graph_kwargs,
)
processor = LangGraphRolloutProcessor(graph_factory=graph_factory)


@pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set")
Expand Down
28 changes: 8 additions & 20 deletions examples/langgraph/test_reasoning_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from eval_protocol.models import EvaluationRow, EvaluateResult, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig

from examples.langgraph.reasoning_gpt_oss_120b_graph import build_reasoning_graph
from .reasoning_gpt_oss_120b_graph import build_reasoning_graph
import os
import pytest

Expand All @@ -24,21 +25,11 @@ def adapter(raw_rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
return rows


def build_graph_kwargs(cp: Dict[str, Any]) -> Dict[str, Any]:
    """Translate completion params into graph kwargs for the reasoning example.

    Defaults ``model`` to the gpt-oss-120b Fireworks model and ``temperature``
    to ``0.0``; ``reasoning_effort`` is passed through as-is (``None`` if absent).
    """
    config = {
        "model": cp.get("model", "accounts/fireworks/models/gpt-oss-120b"),
        "temperature": cp.get("temperature", 0.0),
        "reasoning_effort": cp.get("reasoning_effort"),
    }
    return {"config": config}


def graph_factory(graph_kwargs: Dict[str, Any]) -> Any:
cfg = graph_kwargs.get("config", {}) if isinstance(graph_kwargs, dict) else {}
model = cfg.get("model") or "accounts/fireworks/models/gpt-oss-120b"
temperature = cfg.get("temperature", 0.0)
reasoning_effort = cfg.get("reasoning_effort")
def graph_factory(config: RolloutProcessorConfig) -> Any:
cp = config.completion_params or {}
model = cp.get("model") or "accounts/fireworks/models/gpt-oss-120b"
temperature = cp.get("temperature", 0.0)
reasoning_effort = cp.get("reasoning_effort")
return build_reasoning_graph(
model=model,
model_provider="fireworks",
Expand All @@ -47,10 +38,7 @@ def graph_factory(graph_kwargs: Dict[str, Any]) -> Any:
)


processor = LangGraphRolloutProcessor(
graph_factory=graph_factory,
build_graph_kwargs=build_graph_kwargs,
)
processor = LangGraphRolloutProcessor(graph_factory=graph_factory)


@pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set")
Expand Down
5 changes: 3 additions & 2 deletions tests/pytest/test_langgraph_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.default_langchain_rollout_processor import LangGraphRolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig


class DummyLCMessage:
Expand All @@ -25,7 +26,7 @@ async def ainvoke(self, payload: Dict[str, Any], **_: Any):


def _make_processor_with_defaults(out_messages: List[Any]) -> LangGraphRolloutProcessor:
    """Return a processor whose graph factory ignores the config and always
    yields a ``DummyGraph`` that emits *out_messages*."""

    def graph_factory(_: RolloutProcessorConfig):
        # The config is irrelevant for the dummy graph; it is accepted only
        # to satisfy the factory signature.
        return DummyGraph(out_messages)

    return LangGraphRolloutProcessor(graph_factory=graph_factory)
Expand Down Expand Up @@ -116,7 +117,7 @@ async def ainvoke(self, payload, **_):
# Ensure our adapter-produced messages flow through
return payload

processor = LangGraphRolloutProcessor(graph_factory=lambda _: EchoGraph())
processor = LangGraphRolloutProcessor(graph_factory=lambda _config: EchoGraph())

# Act
tasks = processor(
Expand Down
Loading