-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathevaluate.py
More file actions
executable file
·203 lines (171 loc) · 6.92 KB
/
evaluate.py
File metadata and controls
executable file
·203 lines (171 loc) · 6.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import re
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import parse_arguments, start_experiment, save_results, build_model, build_dataset
import os
import json
def extract_prediction(text):
    """Parse a numeric prediction out of raw model output text.

    Preference order: an explicit ``answer: N`` pattern (case-insensitive),
    then the first bare integer anywhere in the text, then -1 as the
    no-number sentinel.
    """
    # Explicit "answer: N" wins; match against a lowercased copy so the
    # label is case-insensitive.
    if (match := re.search(r'answer:\s*(\d+)', text.lower())):
        return int(match.group(1))
    # Otherwise fall back to the first run of digits in the original text.
    if (match := re.search(r'\d+', text)):
        return int(match.group(0))
    # Sentinel: nothing numeric found.
    return -1
def evaluate_skill(model, dataset, collator, device, skill, args):
    """Evaluate `model` on one skill's dataset and return per-sample results.

    Runs generation over the whole dataset, decodes outputs to text, parses
    a numeric prediction from each, and pairs it with the sample metadata.

    Args:
        model: the (already built) vision-language model to evaluate.
        dataset: dataset for this skill, consumed via a DataLoader.
        collator: batch collator; also exposes the processor/tokenizer used
            for decoding.
        device: device the model lives on; inputs are moved there.
        skill: name of the skill being evaluated (used for logging only).
        args: parsed CLI arguments (model name, eval_batch_size, split, ...).

    Returns:
        A list of per-sample result dicts (sample_id, full_prediction,
        prediction, skill). For the 'val' split each dict also carries
        solution_index/correct, and a summary dict with the accuracy and
        sample count is inserted at index 0.
    """
    model.eval()
    all_outputs = []
    all_metadata = []

    dataloader = DataLoader(
        dataset,
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=10,
        collate_fn=collator
    )

    # First pass: collect all model outputs (token ids, or text for minicpm).
    for inputs, metadata in tqdm(dataloader, desc=f"Evaluating {skill}"):
        with torch.no_grad():
            model_name = args.model.lower()
            if 'molmo' in model_name:
                # Molmo has its own batch-generation entry point and needs
                # a GenerationConfig plus an explicit tokenizer.
                from transformers import GenerationConfig
                inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
                outputs = model.generate_from_batch(
                    inputs,
                    GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
                    tokenizer=collator.processor.tokenizer)
                all_outputs.extend(outputs.cpu())
            elif 'minicpm' in model_name:
                # MiniCPM's chat() already returns decoded text, not token ids.
                res, context, _ = model.chat(
                    **inputs,
                    max_new_tokens=512,
                    tokenizer=collator.processor,
                    disable_compile=True
                )
                all_outputs.append(res)
            else:
                inputs = {k: v.to(device) for k, v in inputs.items()}
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=512,
                    # https://github.com/huggingface/transformers/issues/36544
                    disable_compile='paligemma' in model_name
                )
                all_outputs.extend(outputs.cpu())
        all_metadata.append(metadata)

    # Second pass: turn outputs into text predictions.
    if 'minicpm' in args.model.lower():
        # Already decoded by chat().
        predictions = all_outputs
    elif 'molmo' in args.model.lower():
        print(f"Decoding {skill} outputs...")
        predictions = collator.processor.tokenizer.batch_decode(all_outputs, skip_special_tokens=True)
    else:
        # Batch decode all outputs
        print(f"Decoding {skill} outputs...")
        predictions = collator.processor.batch_decode(all_outputs, skip_special_tokens=True)

    # Process all predictions at once
    print(f"Processing {skill} predictions...")
    numerical_preds = [extract_prediction(pred) for pred in predictions]

    # Flatten per-batch metadata in linear time (sum((...), []) is quadratic).
    combined_metadata = {
        'sample_ids': [sid for m in all_metadata for sid in m['sample_ids']],
        'labels': [label for m in all_metadata for label in m['labels']],
    }

    # Create results
    results = []
    if args.split == 'val':
        total_correct = 0
        total_samples = 0
    for pred_text, pred_num, sample_id, solution_idx in zip(
        predictions, numerical_preds,
        combined_metadata['sample_ids'],
        combined_metadata['labels']
    ):
        result = {
            'sample_id': sample_id,
            'full_prediction': pred_text,
            'prediction': pred_num,
            'skill': skill,
        }
        if args.split == 'val':
            result['solution_index'] = solution_idx
            correct = pred_num == solution_idx
            result['correct'] = correct
            total_correct += correct
            total_samples += 1
        results.append(result)

    # Add summary result for validation set
    if args.split == 'val':
        summary = {
            'sample_id': 'summary',
            # Guard the empty-dataset case instead of raising ZeroDivisionError.
            'accuracy': total_correct / total_samples if total_samples else 0.0,
            'total_samples': total_samples,
        }
        results.insert(0, summary)
    return results
def main():
    """Entry point: build the model, evaluate each requested skill, save and
    (optionally) log results.

    Reads configuration from CLI args; for the 'val' split it also computes
    a sample-weighted overall accuracy across skills, writes it to
    overall_results.json, and logs to Weights & Biases when enabled.
    """
    args = parse_arguments()
    experiment_name = start_experiment(args)

    # Bind wandb once up front. Previously the import lived inside the
    # per-skill loop, leaving `wb` unbound for the overall logging below
    # whenever the loop never reached it (e.g. an empty skill list).
    if args.wandb:
        import wandb as wb

    # Build model and move to device
    with torch.no_grad():
        model, collator = build_model(args)
    device = next(model.parameters()).device

    # Per-skill aggregates. Accuracy defaults to 0.0 so the weighted
    # reduction below never hits a missing key (the old "correct": 0 entry
    # was dead — it was never updated or read).
    overall_results = {
        skill: {"num_examples": 0, "accuracy": 0.0} for skill in args.skill
    }

    # Evaluate each skill
    for skill in args.skill:
        dataset = build_dataset(args.dataset, skill, args.split, args.single_image, args)
        results = evaluate_skill(model, dataset, collator, device, skill, args)
        save_results(results, experiment_name, skill, args)

        # Update overall results for validation
        if args.split == 'val':
            summary = results[0]  # evaluate_skill puts the summary dict first
            overall_results[skill]["num_examples"] = summary["total_samples"]
            overall_results[skill]["accuracy"] = summary["accuracy"]
            if args.wandb:
                wb.log({
                    f"{skill}/accuracy": summary["accuracy"],
                    f"{skill}/samples": summary["total_samples"]
                })

    # Calculate and log overall results for validation
    if args.split == 'val':
        total_examples = sum(data["num_examples"] for data in overall_results.values())
        # Sample-weighted mean accuracy; guard the zero-sample case.
        overall_accuracy = (
            sum(
                data["num_examples"] * data["accuracy"]
                for data in overall_results.values()
            ) / total_examples
        ) if total_examples else 0.0

        overall_summary = {
            "overall_accuracy": overall_accuracy,
            "total_samples": total_examples,
            "skill_results": overall_results
        }
        with open(os.path.join(args.save_dir, experiment_name, "overall_results.json"), "w") as f:
            json.dump(overall_summary, f, indent=2)

        # Log overall results to wandb (kept inside the args.wandb guard so
        # `wb` is only touched when wandb is enabled).
        if args.wandb:
            wb.log({
                "overall/accuracy": overall_accuracy,
                "overall/samples": total_examples
            })
            # Create a summary table
            if wb.run is not None:
                rows = [[skill, data["accuracy"], data["num_examples"]]
                        for skill, data in overall_results.items()]
                rows.append(["Overall", overall_accuracy, total_examples])
                table = wb.Table(data=rows, columns=["Skill", "Accuracy", "Samples"])
                wb.log({"results_table": table})


if __name__ == "__main__":
    main()