|
| 1 | +# SPDX-FileCopyrightText: (c) UIUC PurpCode Team |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
import json
import subprocess
from typing import Optional

from evalplus.sanitize import sanitize

from eval.generate import preprocess_generation
| 11 | + |
# TODO(@zhewang2001): allow users to play LLM judge based on vLLM, instead of relying on bedrock
# Default judge model in "bedrock/<model-id>" provider-prefixed format, used
# whenever the caller does not pass an explicit `llm_judge`.
DEFAULT_LLM_JUDGE = "bedrock/us.meta.llama3-3-70b-instruct-v1:0"
| 14 | + |
| 15 | + |
def to_evalplus_format(generation_path: str) -> str:
    """Convert a chat-style generation JSONL into EvalPlus sample format.

    Each input line is a JSON object with a "task_id" and a "messages"
    list; the last message's content is sanitized into a bare code
    solution (via evalplus's `sanitize`).

    Args:
        generation_path: Path to a ".jsonl" generation file.

    Returns:
        Path of the newly written "<stem>.evalplus.jsonl" file.

    Raises:
        ValueError: If `generation_path` does not end with ".jsonl".
            (Previously `str.replace` was used, which silently produced
            an output path equal to the input — overwriting it — when
            the suffix was absent, and rewrote ".jsonl" anywhere in the
            path rather than only at the end.)
    """
    if not generation_path.endswith(".jsonl"):
        raise ValueError(f"Expected a .jsonl generation file, got: {generation_path}")
    evalplus_path = generation_path.removesuffix(".jsonl") + ".evalplus.jsonl"

    # Stream line-by-line instead of materializing the whole input file.
    with open(generation_path, "r") as fin, open(evalplus_path, "w") as fout:
        for line in fin:
            item = json.loads(line)
            record = {
                "task_id": item["task_id"],
                "solution": sanitize(item["messages"][-1]["content"]),
            }
            fout.write(json.dumps(record) + "\n")
    return evalplus_path
| 33 | + |
| 34 | + |
| 35 | +def evaluate_main( |
| 36 | + task: str, |
| 37 | + generation_path: str, |
| 38 | + oracle: str = None, |
| 39 | + llm_judge: str = None, |
| 40 | + reference_results_path: str = None, |
| 41 | +): |
| 42 | + if oracle is None: # Guessing oracle |
| 43 | + print(f"Guessing oracle for task {task}...") |
| 44 | + if task in ["purpcode/mbppplus", "purpcode/humanevalplus"]: |
| 45 | + oracle = "evalplus" |
| 46 | + elif task.startswith("purpcode/CyberSecEval"): |
| 47 | + field = task.split("/")[-1].split("-")[-1] |
| 48 | + if field == "SCG": |
| 49 | + oracle = "cyberseceval" |
| 50 | + elif field == "MITRE": |
| 51 | + oracle = "malevent" |
| 52 | + elif field == "FRR": |
| 53 | + oracle = "overrefusal" |
| 54 | + else: |
| 55 | + raise ValueError(f"Unknown dataset: {task}") |
| 56 | + elif task.startswith("purpcode/CodeLMSec"): |
| 57 | + oracle = "codeql" |
| 58 | + elif task.startswith("purpcode/XSCode"): |
| 59 | + oracle = "xscode" |
| 60 | + elif task.startswith("purpcode/XSTest"): |
| 61 | + oracle = "overrefusal" |
| 62 | + elif task.startswith("purpcode/PHTest"): |
| 63 | + oracle = "phtest" |
| 64 | + elif task.startswith("purpcode/CWEval"): |
| 65 | + oracle = "cweval" |
| 66 | + else: |
| 67 | + raise ValueError(f"Unknown oracle for {task = }. Please specify.") |
| 68 | + print(f"Guessed oracle: {oracle}") |
| 69 | + |
| 70 | + generation_path = preprocess_generation(generation_path) |
| 71 | + if oracle == "evalplus": |
| 72 | + generation_path = to_evalplus_format(generation_path) |
| 73 | + dataset = task.split("/")[-1].removesuffix("plus") |
| 74 | + assert dataset in ["mbpp", "humaneval"] |
| 75 | + subprocess.run( |
| 76 | + ["evalplus.evaluate", "--dataset", dataset, "--samples", generation_path] |
| 77 | + ) |
| 78 | + elif oracle == "safety": |
| 79 | + from eval.oracles.compound_safety_check import evaluate_compound_safety |
| 80 | + |
| 81 | + evaluate_compound_safety( |
| 82 | + task=task, |
| 83 | + generation_path=generation_path, |
| 84 | + model=llm_judge or DEFAULT_LLM_JUDGE, |
| 85 | + min_severity_level="MEDIUM", |
| 86 | + reference_results_path=reference_results_path, |
| 87 | + show_separate_stats=True, |
| 88 | + ) |
| 89 | + elif oracle == "cyberseceval": |
| 90 | + from eval.cyberseceval import evaluate_cyberseceval |
| 91 | + |
| 92 | + evaluate_cyberseceval(generation_path=generation_path, task=task) |
| 93 | + elif oracle == "codeguru": |
| 94 | + from eval.oracles.secure_code_oracles import evaluate_secure_code_gen |
| 95 | + |
| 96 | + evaluate_secure_code_gen( |
| 97 | + generation_path=generation_path, |
| 98 | + min_severity_level="MEDIUM", |
| 99 | + analyzers=["codeguru"], |
| 100 | + ) |
| 101 | + elif oracle == "codeql": |
| 102 | + from eval.oracles.secure_code_oracles import evaluate_secure_code_gen |
| 103 | + |
| 104 | + evaluate_secure_code_gen( |
| 105 | + generation_path=generation_path, |
| 106 | + min_severity_level="MEDIUM", |
| 107 | + analyzers=["codeql"], |
| 108 | + ) |
| 109 | + elif oracle == "malevent": |
| 110 | + from eval.oracles.compound_safety_check import ( |
| 111 | + evaluate_malicious_event_assistance, |
| 112 | + ) |
| 113 | + |
| 114 | + evaluate_malicious_event_assistance( |
| 115 | + generation_path=generation_path, |
| 116 | + model=llm_judge or DEFAULT_LLM_JUDGE, |
| 117 | + task=task, |
| 118 | + ) |
| 119 | + elif oracle == "overrefusal": |
| 120 | + from eval.oracles.check_secqa import evaluate_secqa_answers |
| 121 | + |
| 122 | + evaluate_secqa_answers( |
| 123 | + generation_path=generation_path, model=llm_judge or DEFAULT_LLM_JUDGE |
| 124 | + ) |
| 125 | + elif oracle == "xscode": |
| 126 | + from eval.oracles.xscode_overrefuse import evaluate_xscode_answers |
| 127 | + |
| 128 | + evaluate_xscode_answers( |
| 129 | + generation_path=generation_path, model=DEFAULT_LLM_JUDGE |
| 130 | + ) |
| 131 | + |
| 132 | + elif oracle == "phtest": |
| 133 | + from eval.phtest import evaluate_phtest |
| 134 | + |
| 135 | + evaluate_phtest( |
| 136 | + generation_path=generation_path, model=llm_judge or DEFAULT_LLM_JUDGE |
| 137 | + ) |
| 138 | + elif oracle == "cweval": |
| 139 | + from eval.cweval import evaluate_cweval |
| 140 | + |
| 141 | + evaluate_cweval(generation_path=generation_path, task=task) |
| 142 | + else: |
| 143 | + raise ValueError(f"Unknown oracle: {oracle}") |
| 144 | + |
| 145 | + |
| 146 | +if __name__ == "__main__": |
| 147 | + from fire import Fire |
| 148 | + |
| 149 | + Fire(evaluate_main) |
0 commit comments