import os
import json
import math
import torch
import numpy as np
from pathlib import Path
from datetime import datetime
from collections import defaultdict
from sklearn.metrics import roc_curve, auc
from deepscaler.rewards.math_utils.utils import extract_answer
def load_datasets(args):
    questions, answers, options, memberships = [], [], [], []
    if args.dataset_name in ["amc23", "aime24", "aime25", "olympiadbench", "minerva_math", "gpqa_diamond"]:
with open(os.path.join(args.dataset_path, args.dataset_name, "test.jsonl"), "r", encoding="utf-8") as f:
for line in f:
record = json.loads(line)
                try:
                    questions.append(record["question"])
                except KeyError:
                    questions.append(record["problem"])
if args.dataset_name in ["aime24", "aime25", "amc23", "math500"]:
answers.append(record["answer"])
elif "minerva_math" in args.dataset_name:
answers.append(extract_answer(record["solution"]))
elif "olympiadbench" in args.dataset_name:
answers.append(record["final_answer"][0])
elif "gpqa_diamond" in args.dataset_name:
answers.append(record["answer"]), options.append(record["options"])
memberships.append(record["membership"])
assert len(questions) == len(answers)
# Shard the datasets
number_samples = math.ceil(len(questions)/args.global_size)
start_shard_index = (args.sharding-1)*number_samples
end_shard_index = args.sharding*number_samples
questions = questions[start_shard_index:end_shard_index]
answers = answers[start_shard_index:end_shard_index]
options = options[start_shard_index:end_shard_index]
memberships = memberships[start_shard_index:end_shard_index]
assert len(questions) == len(answers) == len(memberships)
if "gpqa_diamond" in args.dataset_name:
assert len(options) == len(memberships)
    else:
        raise NotImplementedError(f"Unsupported dataset: {args.dataset_name}")
return questions, answers, options, memberships
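# Worked example of the sharding arithmetic above (hypothetical sizes): with
# 100 questions and args.global_size == 4, number_samples is
# math.ceil(100 / 4) == 25, so shard 2 (args.sharding == 2) receives
# questions[25:50]. The final shard may be shorter when the dataset size is
# not divisible by global_size.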
def prepare_format(questions, options, tokenizer, args):
prompts = []
if "gpqa_diamond" not in args.dataset_name:
for question in questions:
if "qwen" in args.model_name and "deepseek" not in args.model_name.lower():
message = f"<|im_start|>user\n{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n<|im_start|>assistant\n"
elif "llama" in args.model_name and "deepseek" not in args.model_name.lower():
message = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
elif "deepseek" in args.model_name.lower():
message = f"<|begin▁of▁sentence|><|User|>{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|Assistant|><think>\n"
            else:
                raise ValueError(f"Unsupported model name: {args.model_name}")
            prompts.append(message)
else:
for question, option in zip(questions, options):
prompt = question + "\n" + option
full_prompt = f"Return your final response within \\boxed{{}} and only include the letter choice (A, B, C, or D) as your final response. {prompt}"
if "qwen" in args.model_name and "deepseek" not in args.model_name.lower():
message = f"<|im_start|>user\n{full_prompt}<|im_end|>\n<|im_start|>assistant\n"
elif "llama" in args.model_name and "deepseek" not in args.model_name.lower():
message = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{full_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
elif "deepseek" in args.model_name.lower():
message = f"<|begin▁of▁sentence|><|User|>{full_prompt}<|Assistant|><think>\n"
            else:
                raise ValueError(f"Unsupported model name: {args.model_name}")
            prompts.append(message)
return prompts
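# Illustration of the Qwen branch above on a hypothetical question; the other
# branches differ only in their chat-template tokens:
#   prepare_format(["What is 2+2?"], [], tokenizer, args)
#   -> ['<|im_start|>user\nWhat is 2+2?\nPlease reason step by step, and put
#       your final answer within \boxed{}.<|im_end|>\n<|im_start|>assistant\n']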
def save_sharding_outputs(outputs, answers, args, tokenizer):
results = []
for index, output in enumerate(outputs):
logprob_list = []
token_ids_list = []
response_list = []
for i in range(args.rollout_number):
            # vLLM yields one {token_id: Logprob} dict per generated token;
            # grab the sampled token's logprob from each.
            logprob = list(map(lambda d: next(iter(d.values())).logprob, output.outputs[i].logprobs))
token_ids = output.outputs[i].token_ids
assert len(logprob) == len(token_ids)
logprob_list.append(logprob)
token_ids_list.append(token_ids)
response_list.append(tokenizer.decode(output.outputs[i].token_ids))
        # vLLM reports no logprob for the first prompt token, so drop it.
        prompt_logprobs = output.prompt_logprobs[1:]
prompt_logprobs = list(map(lambda d: next(iter(d.values())).logprob, prompt_logprobs))
assert len(output.prompt_token_ids) - 1 == len(prompt_logprobs)
results.append({
"prompt_token_ids": output.prompt_token_ids, "prompt": output.prompt, "prompt_logprob": prompt_logprobs,
"response": response_list, "token_ids": token_ids_list, "logprob": logprob_list,
"answer": answers[index],
})
assert len(logprob_list) == len(token_ids_list) == len(response_list)
now = datetime.now()
formatted_time = now.strftime("%Y-%m-%d %H:%M")
results_path = os.path.join(args.save_path, args.model_name.split("/")[-1], args.dataset_name, args.mia)
os.makedirs(results_path, exist_ok=True)
with open(os.path.join(results_path, "sharding_{}_{}.jsonl".format(args.sharding, formatted_time)), "w", encoding="utf-8") as f:
for rec in results:
json.dump(rec, f, ensure_ascii=False)
f.write("\n")
if args.sharding == 1:
with open(os.path.join(results_path, "args_{}.txt".format(formatted_time)), "w") as f:
f.write(str(args))
print("Sharding [{}/{}] | Saving results to: {}".format(args.sharding, args.global_size, results_path))
def save_sharding_outputs_no_logprob(outputs, answers, memberships, args, tokenizer):
results = []
for index, output in enumerate(outputs):
token_ids_list = []
response_list = []
for i in range(args.rollout_number):
token_ids = output.outputs[i].token_ids
token_ids_list.append(token_ids)
response_list.append(tokenizer.decode(output.outputs[i].token_ids))
result_dict = {
"prompt_token_ids": output.prompt_token_ids, "prompt": output.prompt,
"response": response_list, "token_ids": token_ids_list,
"answer": answers[index], "membership": memberships[index]
}
results.append(result_dict)
now = datetime.now()
formatted_time = now.strftime("%Y-%m-%d %H:%M")
results_path = os.path.join(args.save_path, args.model_name.split("/")[-1], args.dataset_name, args.mia)
os.makedirs(results_path, exist_ok=True)
with open(os.path.join(results_path, "sharding_{}_{}.jsonl".format(args.sharding, formatted_time)), "w", encoding="utf-8") as f:
for rec in results:
json.dump(rec, f, ensure_ascii=False)
f.write("\n")
if args.sharding == 1:
with open(os.path.join(results_path, "args_{}.txt".format(formatted_time)), "w") as f:
f.write(str(args))
print("Sharding [{}/{}] | Saving results to: {}".format(args.sharding, args.global_size, results_path))
def save_sharding_records_result(labels, mia_values, mia, args):
    values_cpu = [x.detach().cpu().float().numpy() for x in mia_values]
# Handle shuffle subfolders
base_path = os.path.join(args.results_path, args.model_name.split("/")[-1], args.dataset_name, "npz_rollout_{}".format(args.num_responses))
os.makedirs(base_path, exist_ok=True)
np.savez(os.path.join(base_path, f"{mia}_shard_{args.sharding}.npz"), list1=labels, list2=values_cpu)
print("Records Sharding [{}/{}] | MIA: {} | Saving results to: {}".format(args.sharding, args.global_size, mia, base_path))
def get_detection_metrics(labels, mia_values, mia):
output_strings = ["Response"]; mia_values = [mia_values]
for strings, values in zip(output_strings, mia_values):
mask = np.isfinite(values)
if not mask.all():
print(f"MIA: {mia} {strings} | Removing {(~mask).sum()} non-finite scores out of {len(values)}")
labels_clean = labels[mask]
values_clean = values[mask]
# Skip ROC computation if no valid data remains
if len(values_clean) == 0:
print(f"MIA: {mia} {strings} - Skipping: no finite values remaining")
continue
fpr_list, tpr_list, thresholds = roc_curve(labels_clean, values_clean)
fpr95 = fpr_list[np.where(tpr_list >= 0.95)[0][0]]
tpr05 = tpr_list[np.where(fpr_list <= 0.05)[0][-1]]
auroc = auc(fpr_list, tpr_list)
if "Response" == strings: # and mia == "LOSS":
print("MIA: {} | {}: AUROC {:.2f}%, FPR95: {:.2f}%, TPR05: {:.2f}%".format(
mia, strings, auroc*100, fpr95*100, tpr05*100
))
def prepare_sharding_records_for_mia(records, args):
number_samples = math.ceil(len(records)/args.global_size)
start_shard_index = (args.sharding-1)*number_samples
end_shard_index = args.sharding*number_samples
return records[start_shard_index:end_shard_index]
def remove_existing_npz(file_dir, prefix_patterns):
for pat in prefix_patterns:
for p in Path(file_dir).glob(pat):
if p.is_file(): # extra safety
p.unlink(missing_ok=True)
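# Example (hypothetical directory and attack names): clear stale per-shard
# outputs before re-running:
#   remove_existing_npz("results/npz_rollout_8",
#                       ["LOSS_shard_*.npz", "Perplexity_shard_*.npz"])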
def group_records_by_prompt(records):
"""Group records by prompt, consolidating responses from duplicate questions"""
grouped = defaultdict(lambda: {
'prompt': None,
'prompt_token_ids': None,
'answer': None,
'membership': None,
'response': [],
'token_ids': []
})
for record in records:
prompt = record['prompt']
# Set metadata from first occurrence of this prompt
        if grouped[prompt]['prompt'] is None:
grouped[prompt]['prompt'] = record['prompt']
grouped[prompt]['prompt_token_ids'] = record['prompt_token_ids']
grouped[prompt]['answer'] = record['answer']
grouped[prompt]['membership'] = record['membership']
# Add all responses for this prompt
if isinstance(record['response'], list):
grouped[prompt]['response'].extend(record['response'])
grouped[prompt]['token_ids'].extend(record['token_ids'])
else:
grouped[prompt]['response'].append(record['response'])
grouped[prompt]['token_ids'].append(record['token_ids'])
return list(grouped.values())
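# Example (hypothetical records): two shard entries sharing a prompt are
# merged into one record holding all four responses:
#   recs = [{"prompt": "Q", "prompt_token_ids": [1], "answer": "4",
#            "membership": 1, "response": ["a", "b"], "token_ids": [[2], [3]]},
#           {"prompt": "Q", "prompt_token_ids": [1], "answer": "4",
#            "membership": 1, "response": ["c", "d"], "token_ids": [[4], [5]]}]
#   group_records_by_prompt(recs)[0]["response"]  # -> ["a", "b", "c", "d"]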
def levenshtein_distance(str1, str2):
if len(str1) > len(str2):
str1, str2 = str2, str1
distances = range(len(str1) + 1)
for index2, char2 in enumerate(str2):
new_distances = [index2 + 1]
for index1, char1 in enumerate(str1):
if char1 == char2:
new_distances.append(distances[index1])
else:
new_distances.append(1 + min((distances[index1], distances[index1 + 1], new_distances[-1])))
distances = new_distances
return distances[-1]
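# Quick check of the dynamic-programming implementation above:
#   levenshtein_distance("kitten", "sitting")  # -> 3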
def strip_code(sample):
if sample is None:
return ""
return sample.strip().split('\n\n\n')[0] if '\n\n\n' in sample else sample.strip().split('```')[0]
def truncate_prompt(sample, method_name):
if f'def {method_name}(' in sample:
output = sample.replace("'''", '"""')
output = output[output.find('def '+method_name):]
output = output[output.find('"""')+3:]
output = output[output.find('"""\n')+4:] if '"""\n' in output else output[output.find('"""')+3:]
else:
output = sample
return output
def calculate_ratio(numbers, alpha=1):
count = sum(1 for num in numbers if num <= alpha)
total = len(numbers)
ratio = count / total if total > 0 else 0
return ratio
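# Example: the fraction of values at or below the threshold alpha:
#   calculate_ratio([0.4, 0.9, 1.5], alpha=1)  # -> 2/3 ≈ 0.667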
def tokenize_code(sample, tokenizer, length):
return tokenizer.encode(sample)[:length] if length else tokenizer.encode(sample)
def get_edit_distance_distribution_star(samples, greedy_sample, tokenizer, length=100):
    greedy_sample = strip_code(greedy_sample)
    gs = tokenize_code(greedy_sample, tokenizer, length)
num = []
max_length = len(gs)
for sample in samples:
sample = strip_code(sample)
s = tokenize_code(sample, tokenizer, length)
num.append(levenshtein_distance(gs, s))
max_length = max(max_length, len(s))
return num, max_length
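# The distances are typically normalized by max_length and passed to
# calculate_ratio to estimate how many sampled completions stay within an
# edit-distance budget of the greedy completion (sketch; samples, greedy, and
# the alpha threshold are hypothetical):
#   dists, max_len = get_edit_distance_distribution_star(samples, greedy, tokenizer)
#   ratio = calculate_ratio([d / max_len for d in dists], alpha=0.05)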
def cal_rouge(rouge, origin_problem, prompt, completion):
target = origin_problem[len(prompt):].strip()
prediction = completion.strip()
prediction = " ".join(prediction.split()[:len(target.split())])
rouge_result = rouge.compute(predictions=[prediction], references=[target], use_stemmer=True)
exact_match = bool(round(rouge_result["rougeL"], 4) == 1.0)
return rouge_result, exact_match
def extract_user_instruction(prompt, model_name):
if "deepseek" in model_name.lower():
return prompt.split("<|User|>")[-1].split("<|Assistant|><think>\n")[0]
elif "qwen" in model_name.lower():
return prompt.split("<|im_start|>user\n")[-1].split("<|im_end|>\n")[0]
elif "llama" in model_name.lower():
return prompt.split("<|start_header_id|>user<|end_header_id|>\n\n")[-1].split("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")[0]
def prepare_perturbed_prompt_response(neighbor, model_name, response):
if "deepseek" in model_name.lower():
return ["<|begin▁of▁sentence|><|User|>"+neighbor+"<|Assistant|><think>\n"+response]
elif "qwen" in model_name.lower():
return ["<|im_start|>user\n"+neighbor+"<|im_end|>\n<|im_start|>assistant\n"+response]
elif "llama" in model_name.lower():
return ["<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"+neighbor+"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"+response]
def prepare_perturbed_prompt_token_ids(tokenizer, neighbor, model_name):
if "deepseek" in model_name.lower():
return tokenizer("<|begin▁of▁sentence|><|User|>"+neighbor+"<|Assistant|><think>\n")["input_ids"]
elif "qwen" in model_name.lower():
return tokenizer("<|im_start|>user\n"+neighbor+"<|im_end|>\n<|im_start|>assistant\n")["input_ids"]
elif "llama" in model_name.lower():
return tokenizer("<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"+neighbor+"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")["input_ids"]
def prepare_reference_model_prompt(ref_model_name, model_name, prompt, response):
user_instruction = extract_user_instruction(prompt, model_name)
if ref_model_name == "bespokelabs/Bespoke-Stratos-7B":
ref_prompt = "<|im_start|>user\n"+user_instruction+"<|im_end|>\n<|im_start|>assistant\n"
return ref_prompt
    else:
        raise NotImplementedError(f"Unsupported reference model: {ref_model_name}")