@@ -178,6 +178,8 @@ def evaluate(
178178
179179 else :
180180
181+ pass_at_k = dict ()
182+
181183 pass_k = [int (k ) for k in pass_k .split ("," )]
182184
183185 if parallel < 1 :
@@ -207,8 +209,6 @@ def evaluate(
207209
208210 results = compatible_eval_result (results )
209211 else :
210- pass_at_k = dict ()
211-
212212 if check_gt_only :
213213
214214 if gt_pass_rate > 0.99 :
@@ -299,30 +299,30 @@ def stucking_checker():
299299 }
300300 )
301301
302- # Calculate pass@k.
303- total = np .array ([len (r ) for k , r in results ["eval" ].items () if k in problems ])
304- base_correct = []
305-
306- for key , res in results ["eval" ].items ():
307- if key not in problems :
308- continue
309- bc = sum ([r ["status" ] == PASS for r in res ])
310- base_correct .append (bc )
311-
312- base_correct = np .array (base_correct )
313-
314- pass_at_k .update ({
315- f"pass@{ k } " : estimate_pass_at_k (total , base_correct , k ).mean ()
316- for k in pass_k
317- if total .min () >= k
318- })
319-
320- pass_at_k ["model" ] = os .path .basename (samples ).split ("--bigcodebench-" )[0 ]
321- pass_at_k ["split" ] = split
322- pass_at_k ["subset" ] = subset
323- pass_at_k ["calibrated" ] = calibrated
324- pass_at_k ["gt_pass_rate" ] = gt_pass_rate
325- pass_at_k ["failed_tasks" ] = failed_tasks
302+ # Calculate pass@k.
303+ total = np .array ([len (r ) for k , r in results ["eval" ].items () if k in problems ])
304+ base_correct = []
305+
306+ for key , res in results ["eval" ].items ():
307+ if key not in problems :
308+ continue
309+ bc = sum ([r ["status" ] == PASS for r in res ])
310+ base_correct .append (bc )
311+
312+ base_correct = np .array (base_correct )
313+
314+ pass_at_k .update ({
315+ f"pass@{ k } " : estimate_pass_at_k (total , base_correct , k ).mean ()
316+ for k in pass_k
317+ if total .min () >= k
318+ })
319+
320+ pass_at_k ["model" ] = os .path .basename (samples ).split ("--bigcodebench-" )[0 ]
321+ pass_at_k ["split" ] = split
322+ pass_at_k ["subset" ] = subset
323+ pass_at_k ["calibrated" ] = calibrated
324+ pass_at_k ["gt_pass_rate" ] = gt_pass_rate
325+ pass_at_k ["failed_tasks" ] = failed_tasks
326326
327327 extra = subset .capitalize ()
328328 split = split .capitalize ()
0 commit comments