2121logger = logging .getLogger (__name__ )
2222
2323# Define the approaches to test
24- # Each approach is (name, description)
24+ # Each approach is (name, description, extra_body_params )
2525APPROACHES = [
26- ("none" , "Baseline without any optimization" ),
27- ("leap" , "LEAP Approach" ),
28- ("rto" , "Round Trip Optimization" ),
29- ("cot_reflection" , "Chain of Thought with Reflection" ),
30- ("self_consistency" , "Self Consistency Check" ),
31- ("plansearch" , "Planning with Search" ),
32- ("re2" , "ReRead Approach" ),
33- ("z3" , "Z3 Solver for Mathematical Problems" ),
34- ("coc" , "Chain of Code" ),
35- ("executecode" , "Execute Code" ),
36- ("spl" , "System Prompt Learning" )
26+ ("none" , "Baseline without any optimization" , {}),
27+ ("leap" , "LEAP Approach" , {}),
28+ ("rto" , "Round Trip Optimization" , {}),
29+ ("cot_reflection" , "Chain of Thought with Reflection" , {}),
30+ ("self_consistency" , "Self Consistency Check" , {}),
31+ ("plansearch" , "Planning with Search" , {}),
32+ ("re2" , "ReRead Approach" , {}),
33+ ("z3" , "Z3 Solver for Mathematical Problems" , {}),
34+ ("coc" , "Chain of Code" , {}),
35+ ("executecode" , "Execute Code" , {}),
36+ ("spl" , "System Prompt Learning" , {})
37+ ]
38+
39+ # Define test-time compute approaches for sequential and parallel scaling
40+ TEST_TIME_COMPUTE_APPROACHES = [
41+ # Baseline
42+ ("none" , "Baseline without any optimization" , {}),
43+
44+ # Sequential test-time compute using thinkdeeper with different thinking budgets
45+ ("thinkdeeper_8k" , "ThinkDeeper with 8K thinking tokens" , {
46+ "decoding" : "thinkdeeper" ,
47+ "max_thinking_tokens" : 8000
48+ }),
49+ ("thinkdeeper_16k" , "ThinkDeeper with 16K thinking tokens" , {
50+ "decoding" : "thinkdeeper" ,
51+ "max_thinking_tokens" : 16000
52+ }),
53+ ("thinkdeeper_32k" , "ThinkDeeper with 32K thinking tokens" , {
54+ "decoding" : "thinkdeeper" ,
55+ "max_thinking_tokens" : 32000
56+ }),
57+
58+ # Parallel test-time compute using majority voting with different k values
59+ ("majority_voting_6" , "Majority Voting with k=6" , {"k" : 6 }),
60+ ("majority_voting_36" , "Majority Voting with k=36" , {"k" : 36 }),
61+ ("majority_voting_60" , "Majority Voting with k=60" , {"k" : 60 }),
3762]
3863
3964def load_optillm_bench () -> datasets .Dataset :
@@ -265,6 +290,7 @@ def evaluate_model(
265290 model : str ,
266291 dataset : datasets .Dataset ,
267292 approach : str ,
293+ approach_extra_body : Dict [str , Any ] = None ,
268294 max_samples : int = None
269295) -> Tuple [Dict [str , float ], List [Dict [str , Any ]]]:
270296 """
@@ -286,8 +312,18 @@ def evaluate_model(
286312 # Prepare the dataset
287313 examples = dataset if max_samples is None else dataset .select (range (max_samples ))
288314
289- # Create model name with approach
290- full_model_name = f"{ approach } -{ model } " if approach != "none" else model
315+ # Create model name with approach - handle special cases
316+ if approach == "none" :
317+ full_model_name = model
318+ elif approach .startswith ("thinkdeeper_" ):
319+ # For thinkdeeper, use base model name (decoding is passed in extra_body)
320+ full_model_name = model
321+ elif approach .startswith ("majority_voting_" ):
322+ # For majority voting, use majority_voting prefix
323+ full_model_name = f"majority_voting-{ model } "
324+ else :
325+ # Standard approach prefix
326+ full_model_name = f"{ approach } -{ model } "
291327
292328 for example in tqdm (examples , desc = f"Evaluating { approach } " ):
293329 try :
@@ -297,6 +333,11 @@ def evaluate_model(
297333 # Record start time
298334 start_time = time .time ()
299335
336+ # Prepare extra_body parameters
337+ extra_body = {"spl_learning" : False }
338+ if approach_extra_body :
339+ extra_body .update (approach_extra_body )
340+
300341 # Make API call
301342 response = client .chat .completions .create (
302343 model = full_model_name ,
@@ -306,7 +347,7 @@ def evaluate_model(
306347 ],
307348 temperature = 0.2 ,
308349 max_tokens = 4096 ,
309- extra_body = { "spl_learning" : False } ,
350+ extra_body = extra_body ,
310351 )
311352
312353 # Calculate time taken
@@ -469,6 +510,8 @@ def main():
469510 help = "Directory to save results" )
470511 parser .add_argument ("--approaches" , nargs = "+" ,
471512 help = "Specific approaches to evaluate (default: all)" )
513+ parser .add_argument ("--test-time-compute" , action = "store_true" ,
514+ help = "Evaluate test-time compute approaches (sequential and parallel scaling)" )
472515 parser .add_argument ("--debug" , action = "store_true" , help = "Enable debug logging" )
473516 args = parser .parse_args ()
474517
0 commit comments