#
# SPDX-License-Identifier: Apache-2.0

import json
from collections import Counter

from datasets import load_dataset
from rich import print as rprint

from eval.oracles.check_secqa import evaluate_secqa_answers, is_bad_answer

# Harmfulness labels that are tallied, in the order they are reported.
_REPORTED_LABELS = ("harmless", "controversial")


def _format_rate(label: str, rejected: int, total: int) -> str:
    """Build one rich-markup report line; the rate is 0.0 when *total* is 0."""
    rate = 100 * rejected / total if total else 0.0
    return f"[green]{label} rejection rate:{rejected} / {total} = {rate:.1f}%[/green]"


def evaluate_phtest(
    generation_path: str,
    model: str,
) -> None:
    """Print rejection rates on PHTest, split by harmfulness label.

    Loads the ``purpcode/PHTest`` test split to map each ``task_id`` to its
    ``Harmfulness`` label, runs ``evaluate_secqa_answers`` on the generations,
    then reads the ``.reject-intermediate.jsonl`` file next to
    ``generation_path`` and counts, per label, how many final messages
    ``is_bad_answer`` flags.

    Args:
        generation_path: Path to the ``.jsonl`` file with model generations.
        model: Passed through unchanged to ``evaluate_secqa_answers``.

    Raises:
        OSError: If the intermediate file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a row lacks ``task_id`` or a counted row lacks
            ``messages[-1]["content"]``.
    """
    dataset = load_dataset("purpcode/PHTest", split="test")
    task_id_to_harmfulness = {row["task_id"]: row["Harmfulness"] for row in dataset}

    # Presumably this call is what writes the intermediate file read below.
    evaluate_secqa_answers(generation_path, model)

    # NOTE(review): str.replace rewrites every ".jsonl" occurrence in the path;
    # kept as-is because it presumably has to match the path the oracle writes.
    intermediate_path = generation_path.replace(".jsonl", ".reject-intermediate.jsonl")
    # Explicit UTF-8 so decoding does not depend on the platform locale;
    # blank lines are skipped instead of crashing json.loads.
    with open(intermediate_path, "r", encoding="utf-8") as f:
        intermediate_results = [json.loads(line) for line in f if line.strip()]

    totals = Counter()
    rejected = Counter()
    skipped = 0

    for row in intermediate_results:
        harmfulness = task_id_to_harmfulness.get(row["task_id"])
        if harmfulness not in _REPORTED_LABELS:
            # Task id missing from the dataset, or a label that is not reported.
            skipped += 1
            continue
        totals[harmfulness] += 1
        rejected[harmfulness] += int(is_bad_answer(row["messages"][-1]["content"]))

    if skipped:
        # Surface dropped rows so shrunken totals are not mistaken for a full run.
        rprint(
            f"[yellow]Skipped {skipped} row(s) with a missing or unreported harmfulness label[/yellow]"
        )

    rprint(_format_rate("Harmless", rejected["harmless"], totals["harmless"]))
    rprint(
        _format_rate("Controversial", rejected["controversial"], totals["controversial"])
    )
    rprint(_format_rate("Total", sum(rejected.values()), sum(totals.values())))