|
3 | 3 | import re |
4 | 4 | from typing import Any, Dict, List, Optional |
5 | 5 |
|
6 | | -from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult |
| 6 | +from eval_protocol.models import ( |
| 7 | + EvaluateResult, |
| 8 | + EvaluationRow, |
| 9 | + Message, |
| 10 | + MetricResult, |
| 11 | + ChatCompletionContentPartTextParam, |
| 12 | +) |
7 | 13 | from eval_protocol.pytest.default_single_turn_rollout_process import ( |
8 | 14 | SingleTurnRolloutProcessor, |
9 | 15 | ) |
@@ -31,6 +37,12 @@ def _extract_last_boxed_segment(text: str) -> Optional[str]: |
31 | 37 | return matches[-1] |
32 | 38 |
|
33 | 39 |
|
| 40 | +def _coerce_content_to_str(content: str | list[ChatCompletionContentPartTextParam] | None) -> str: |
| 41 | + if isinstance(content, list): |
| 42 | + return "".join([getattr(p, "text", str(p)) for p in content]) |
| 43 | + return str(content or "") |
| 44 | + |
| 45 | + |
34 | 46 | def _cta_process_results(ground_truth: str, llm_answer: str) -> int: |
35 | 47 | parsed_answer = llm_answer |
36 | 48 | if "\\boxed{" in parsed_answer or "\\framebox{" in parsed_answer: |
@@ -275,6 +287,8 @@ def _read_jsonl_table_from_text(text: str, header_cols: List[str]): |
275 | 287 | return 0 |
276 | 288 |
|
277 | 289 | # Compare |
| 290 | + assert llm_df is not None, "LLM dataframe is None" |
| 291 | + assert gt_df is not None, "GT dataframe is None" |
278 | 292 | try: |
279 | 293 | gt_df.columns = [str(s).strip() for s in gt_df.columns] |
280 | 294 | if "index" in gt_df.columns: |
@@ -420,7 +434,8 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]: |
420 | 434 | ) |
421 | 435 | def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow: |
422 | 436 | assistant_msgs = [m for m in row.messages if m.role == "assistant"] |
423 | | - content = assistant_msgs[-1].content if assistant_msgs else "" |
| 437 | + raw_content = assistant_msgs[-1].content if assistant_msgs else "" |
| 438 | + content = _coerce_content_to_str(raw_content) |
424 | 439 | payload = _extract_gt(row) |
425 | 440 | gt = payload.get("ground_truth") |
426 | 441 | gt_str = str(gt) if gt is not None else "" |
@@ -462,9 +477,9 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow: |
462 | 477 | ) |
463 | 478 | def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow: |
464 | 479 | user_msgs = [m for m in row.messages if m.role == "user"] |
465 | | - question = user_msgs[-1].content if user_msgs else "" |
| 480 | + question = _coerce_content_to_str(user_msgs[-1].content if user_msgs else "") |
466 | 481 | assistant_msgs = [m for m in row.messages if m.role == "assistant"] |
467 | | - content = assistant_msgs[-1].content if assistant_msgs else "" |
| 482 | + content = _coerce_content_to_str(assistant_msgs[-1].content if assistant_msgs else "") |
468 | 483 | payload = _extract_gt(row) |
469 | 484 | gt = payload.get("ground_truth") |
470 | 485 |
|
@@ -505,9 +520,9 @@ def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow: |
505 | 520 | ) |
506 | 521 | def test_livebench_tablereformat_pointwise(row: EvaluationRow) -> EvaluationRow: |
507 | 522 | user_msgs = [m for m in row.messages if m.role == "user"] |
508 | | - question = user_msgs[-1].content if user_msgs else "" |
| 523 | + question = _coerce_content_to_str(user_msgs[-1].content if user_msgs else "") |
509 | 524 | assistant_msgs = [m for m in row.messages if m.role == "assistant"] |
510 | | - content = assistant_msgs[-1].content if assistant_msgs else "" |
| 525 | + content = _coerce_content_to_str(assistant_msgs[-1].content if assistant_msgs else "") |
511 | 526 | payload = _extract_gt(row) |
512 | 527 | gt = payload.get("ground_truth") |
513 | 528 | release = payload.get("release") or "" |
|
0 commit comments