@@ -21,12 +21,16 @@ class LangGraphRolloutProcessor(RolloutProcessor):
2121 def __init__ (
2222 self ,
2323 * ,
24- graph_factory : Callable [[Dict [str , Any ]], Any ],
24+ # Prefer factory that accepts RolloutProcessorConfig for parity with Pydantic pattern.
25+ # For backward compatibility, factories accepting a Dict[str, Any] (graph kwargs) are still supported.
26+ graph_factory : Callable [[Any ], Any ],
2527 to_input : Optional [Callable [[EvaluationRow ], Dict [str , Any ]]] = None ,
2628 apply_result : Optional [Callable [[EvaluationRow , Any ], EvaluationRow ]] = None ,
2729 build_graph_kwargs : Optional [Callable [[CompletionParams ], Dict [str , Any ]]] = None ,
2830 input_key : str = "messages" ,
2931 output_key : str = "messages" ,
32+ # Optional: build per-invoke RunnableConfig dict from full RolloutProcessorConfig
33+ build_invoke_config : Optional [Callable [[RolloutProcessorConfig ], Dict [str , Any ]]] = None ,
3034 ) -> None :
3135 # Build the graph per-call using completion_params
3236 self ._graph_factory = graph_factory
@@ -35,6 +39,7 @@ def __init__(
3539 self ._build_graph_kwargs = build_graph_kwargs
3640 self ._input_key = input_key
3741 self ._output_key = output_key
42+ self ._build_invoke_config = build_invoke_config
3843
3944 def _default_to_input (self , row : EvaluationRow ) -> Dict [str , Any ]:
4045 messages = row .messages or []
@@ -121,14 +126,25 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
121126 if config .completion_params :
122127 graph_config = build_kwargs (config .completion_params )
123128
124- # (Re)build the graph for this call using the graph kwargs
125- graph_target = self ._graph_factory (graph_config or {})
129+ # (Re)build the graph for this call. Prefer passing full config to factory;
130+ # fall back to old dict-based factories if needed.
131+ try :
132+ graph_target = self ._graph_factory (config ) # type: ignore[arg-type]
133+ except TypeError :
134+ graph_target = self ._graph_factory (graph_config or {})
135+
136+ # Build per-invoke config if provided; otherwise reuse graph_config for backwards compat
137+ invoke_config : Optional [Dict [str , Any ]] = None
138+ if self ._build_invoke_config is not None :
139+ invoke_config = self ._build_invoke_config (config )
140+ elif graph_config is not None :
141+ invoke_config = graph_config
126142
127143 async def _process_row (row : EvaluationRow ) -> EvaluationRow :
128144 try :
129145 payload = to_input (row )
130- if graph_config is not None :
131- result = await graph_target .ainvoke (payload , config = graph_config )
146+ if invoke_config is not None :
147+ result = await graph_target .ainvoke (payload , config = invoke_config )
132148 else :
133149 result = await graph_target .ainvoke (payload )
134150 row = apply_result (row , result )
0 commit comments