Skip to content

Commit aa634d5

Browse files
authored
fix(evaluate): compute pass_at_k for existing results
1 parent 3828e62 commit aa634d5

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

bigcodebench/evaluate.py

Lines changed: 26 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -178,6 +178,8 @@ def evaluate(
178178

179179
else:
180180

181+
pass_at_k = dict()
182+
181183
pass_k = [int(k) for k in pass_k.split(",")]
182184

183185
if parallel < 1:
@@ -207,8 +209,6 @@ def evaluate(
207209

208210
results = compatible_eval_result(results)
209211
else:
210-
pass_at_k = dict()
211-
212212
if check_gt_only:
213213

214214
if gt_pass_rate > 0.99:
@@ -299,30 +299,30 @@ def stucking_checker():
299299
}
300300
)
301301

302-
# Calculate pass@k.
303-
total = np.array([len(r) for k, r in results["eval"].items() if k in problems])
304-
base_correct = []
305-
306-
for key, res in results["eval"].items():
307-
if key not in problems:
308-
continue
309-
bc = sum([r["status"] == PASS for r in res])
310-
base_correct.append(bc)
311-
312-
base_correct = np.array(base_correct)
313-
314-
pass_at_k.update({
315-
f"pass@{k}": estimate_pass_at_k(total, base_correct, k).mean()
316-
for k in pass_k
317-
if total.min() >= k
318-
})
319-
320-
pass_at_k["model"] = os.path.basename(samples).split("--bigcodebench-")[0]
321-
pass_at_k["split"] = split
322-
pass_at_k["subset"] = subset
323-
pass_at_k["calibrated"] = calibrated
324-
pass_at_k["gt_pass_rate"] = gt_pass_rate
325-
pass_at_k["failed_tasks"] = failed_tasks
302+
# Calculate pass@k.
303+
total = np.array([len(r) for k, r in results["eval"].items() if k in problems])
304+
base_correct = []
305+
306+
for key, res in results["eval"].items():
307+
if key not in problems:
308+
continue
309+
bc = sum([r["status"] == PASS for r in res])
310+
base_correct.append(bc)
311+
312+
base_correct = np.array(base_correct)
313+
314+
pass_at_k.update({
315+
f"pass@{k}": estimate_pass_at_k(total, base_correct, k).mean()
316+
for k in pass_k
317+
if total.min() >= k
318+
})
319+
320+
pass_at_k["model"] = os.path.basename(samples).split("--bigcodebench-")[0]
321+
pass_at_k["split"] = split
322+
pass_at_k["subset"] = subset
323+
pass_at_k["calibrated"] = calibrated
324+
pass_at_k["gt_pass_rate"] = gt_pass_rate
325+
pass_at_k["failed_tasks"] = failed_tasks
326326

327327
extra = subset.capitalize()
328328
split = split.capitalize()

0 commit comments

Comments (0)