11"""Check: validate a trace against scene expectations."""
22
3- from concurrent .futures import ThreadPoolExecutor , as_completed
43from dataclasses import dataclass , field
54from pathlib import Path
65from typing import Any
76
7+ from .batch import BatchExecutor
88from .metrics import MetricRegistry , MetricResult
99from .models import Expectations
1010from .trace import Trace
@@ -214,6 +214,51 @@ def passed(self) -> bool:
214214 return self .error is None and self .check_result .passed
215215
216216
@dataclass
class _EvaluationTask:
    """Bundle of everything needed to evaluate a single trace in a batch.

    Instances are built by ``evaluate_batch`` and consumed one at a time by
    the batch executor.
    """

    # Identifier echoed back on the EvaluationResult produced for this task.
    trace_id: str
    # The trace to be evaluated.
    trace: Trace
    # The expectations the trace is evaluated against.
    expectations: Expectations
224+
225+
class _EvaluationExecutor(BatchExecutor[_EvaluationTask, EvaluationResult]):
    """Batch executor that evaluates one trace per task.

    Every task is passed to ``evaluate`` together with the metrics and judge
    model configured on the executor. Exceptions are never raised out of
    ``execute_one``; they are reported through the ``error`` field of the
    returned ``EvaluationResult`` instead.
    """

    def __init__(
        self,
        metrics: list[str] | None,
        judge_model: str | None,
        result_storage: Any,
        parallel: int = 1,
    ):
        """Store the evaluation settings shared by all tasks.

        Args:
            metrics: Forwarded unchanged to ``evaluate`` for every task.
            judge_model: Forwarded unchanged to ``evaluate`` for every task.
            result_storage: Object exposing
                ``save(trace_id=..., check_result=...)``, or ``None`` to skip
                persisting results.
            parallel: Forwarded to ``BatchExecutor.__init__`` (presumably the
                degree of parallelism -- confirm against ``BatchExecutor``).
        """
        super().__init__(parallel)
        self.metrics = metrics
        self.judge_model = judge_model
        self.result_storage = result_storage

    def execute_one(self, item: _EvaluationTask) -> EvaluationResult:
        """Evaluate a single task and optionally persist its result.

        Args:
            item: The task holding the trace, its id and its expectations.

        Returns:
            An ``EvaluationResult`` carrying the check result on success. If
            evaluation *or* saving raises, the result instead carries an
            empty ``CheckResult`` and an ``error`` string made of the
            exception type, its message and the formatted traceback.
        """
        try:
            check_result = evaluate(
                trace=item.trace,
                expectations=item.expectations,
                metrics=self.metrics,
                judge_model=self.judge_model,
            )
            # Compare against None explicitly: a storage object that defines
            # __len__ or __bool__ may be falsy (e.g. while still empty), and a
            # plain truthiness test would then silently skip saving.
            if self.result_storage is not None:
                self.result_storage.save(trace_id=item.trace_id, check_result=check_result)
            return EvaluationResult(trace_id=item.trace_id, check_result=check_result)
        except Exception as e:
            # Deliberately broad: this is the per-task boundary, and a failure
            # is reported on the result rather than propagated to the caller.
            import traceback

            return EvaluationResult(
                trace_id=item.trace_id,
                check_result=CheckResult(),
                error=f"{type(e).__name__}: {e}\n{traceback.format_exc()}",
            )
260+
261+
217262def evaluate_batch (
218263 traces : list [Trace ] | str | Path ,
219264 expectations : Expectations | None = None ,
@@ -237,7 +282,7 @@ def evaluate_batch(
237282 """
238283 from .storage import EvaluationStorage , TraceStorage
239284
240- trace_list : list [tuple [ str , Trace , Expectations ] ] = []
285+ tasks : list [_EvaluationTask ] = []
241286
242287 if isinstance (traces , (str , Path )):
243288 path = Path (traces )
@@ -247,50 +292,22 @@ def evaluate_batch(
247292 trace = data ["trace" ]
248293 scene = data ["scene" ]
249294 exp = expectations if expectations else scene .expectations
250- trace_list .append ((trace_id , trace , exp ))
295+ tasks .append (_EvaluationTask (trace_id = trace_id , trace = trace , expectations = exp ))
251296 else :
252297 for i , trace in enumerate (traces ):
253298 trace_id = f"trace_{ i } "
254299 exp = expectations if expectations else Expectations ()
255- trace_list .append ((trace_id , trace , exp ))
300+ tasks .append (_EvaluationTask (trace_id = trace_id , trace = trace , expectations = exp ))
256301
257302 result_storage = None
258303 if output :
259304 result_storage = EvaluationStorage (path = Path (output ))
260305
261- results : list [EvaluationResult ] = []
262-
263- def evaluate_single (trace_id : str , trace : Trace , exp : Expectations ) -> EvaluationResult :
264- try :
265- check_result = evaluate (
266- trace = trace ,
267- expectations = exp ,
268- metrics = metrics ,
269- judge_model = judge_model ,
270- )
271- if result_storage :
272- result_storage .save (trace_id = trace_id , check_result = check_result )
273- return EvaluationResult (trace_id = trace_id , check_result = check_result )
274- except Exception as e :
275- import traceback
306+ executor = _EvaluationExecutor (
307+ metrics = metrics ,
308+ judge_model = judge_model ,
309+ result_storage = result_storage ,
310+ parallel = parallel ,
311+ )
276312
277- return EvaluationResult (
278- trace_id = trace_id ,
279- check_result = CheckResult (),
280- error = f"{ type (e ).__name__ } : { e } \n { traceback .format_exc ()} " ,
281- )
282-
283- if parallel <= 1 :
284- for trace_id , trace , exp in trace_list :
285- result = evaluate_single (trace_id , trace , exp )
286- results .append (result )
287- else :
288- with ThreadPoolExecutor (max_workers = parallel ) as executor :
289- futures = {
290- executor .submit (evaluate_single , trace_id , trace , exp ): trace_id
291- for trace_id , trace , exp in trace_list
292- }
293- for future in as_completed (futures ):
294- results .append (future .result ())
295-
296- return results
313+ return executor .run (tasks )
0 commit comments