forked from chanhee-luke/RoboSpatial-Eval
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluation.py
More file actions
345 lines (290 loc) · 13.7 KB
/
evaluation.py
File metadata and controls
345 lines (290 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
"""
Evaluation logic for RoboSpatial-Home benchmark.
Contains functions for evaluating model responses against ground truth data.
"""
import os
import re
import ast
from tqdm import tqdm
def point_in_polygon(x, y, poly):
    """
    Return True if the point (x, y) lies inside the polygon *poly*,
    given as a list of (x, y) vertex tuples.

    Implements the classic ray-casting (even-odd) test: a horizontal ray
    from the point crosses the polygon boundary an odd number of times
    exactly when the point is inside.
    """
    inside = False
    num_vertices = len(poly)
    for idx in range(num_vertices):
        x1, y1 = poly[idx]
        # poly[idx - 1] pairs each vertex with its predecessor; for idx == 0
        # this wraps to the last vertex, closing the polygon.
        x2, y2 = poly[idx - 1]
        crosses = (y1 > y) != (y2 > y)
        if crosses and x < (x2 - x1) * (y - y1) / (y2 - y1) + x1:
            inside = not inside
    return inside
# A signed decimal number, accepting forms like "12", "0.5", "5." and ".5".
# The optional sign and the leading-dot alternative generalize the original
# pattern, which silently skipped coordinates such as "-0.1" or ".5".
_COORD_NUM = r'[-+]?(?:\d+\.?\d*|\.\d+)'
# First "(x, y)" tuple / "[x, y]" pair appearing anywhere in free-form text.
_TUPLE_POINT_RE = re.compile(r'\(\s*(' + _COORD_NUM + r')\s*,\s*(' + _COORD_NUM + r')\s*\)')
_LIST_POINT_RE = re.compile(r'\[\s*(' + _COORD_NUM + r')\s*,\s*(' + _COORD_NUM + r')\s*\]')


def _first_coordinate(text, pattern):
    """Return the first (x, y) float pair matched by *pattern* in *text*, or None."""
    match = pattern.search(text)
    if match is None:
        return None
    try:
        return float(match.group(1)), float(match.group(2))
    except (ValueError, TypeError):
        return None


def _point_from_literal(value):
    """
    Normalize a value produced by ast.literal_eval into an (x, y) tuple.

    Accepts a bare numeric pair ([x, y] or (x, y)), a list of tuples
    [(x, y), ...] (first point wins), or a list of pairs [[x, y], ...].
    Returns None for anything else.
    """
    if isinstance(value, list):
        if not value:
            return None
        if len(value) == 2 and all(isinstance(v, (int, float)) for v in value):
            return tuple(value)  # the list itself is the point [x, y]
        first = value[0]
        if isinstance(first, tuple):
            point = first
        elif isinstance(first, list) and len(first) == 2:
            point = tuple(first)
        else:
            return None
    elif isinstance(value, tuple):
        point = value
    else:
        return None
    return point if len(point) == 2 else None


def evaluate_answer(ground_truth, generated_answer):
    """
    Evaluate whether the generated answer is correct for the ground truth.

    Two question types are supported:
      * Binary ground truth ("yes"/"no"): the response is graded by whether
        it starts with the expected word.
      * Spatial ground truth: a string literal list of >= 3 (x, y) vertices
        defining a polygon; the response is correct when the first point it
        mentions lies inside that polygon (via point_in_polygon).

    Args:
        ground_truth: Expected answer ("yes"/"no" or a polygon literal string).
        generated_answer: Raw model output to grade.

    Returns:
        Tuple (is_correct, is_binary_answer, parsed_answer, is_parsable).
    """
    gen_answer = generated_answer.strip().lower()
    gt_lower = ground_truth.strip().lower()

    # --- Binary yes/no questions --------------------------------------
    if gt_lower in ("yes", "no"):
        # Any non-empty response counts as parsable for binary questions.
        is_parsable = len(gen_answer) > 0
        correct = gen_answer.startswith(gt_lower)
        return correct, True, gen_answer, is_parsable

    # --- Spatial questions: point-in-polygon --------------------------
    parsed_answer = None
    is_parsable = False  # stays False until a point is successfully parsed
    try:
        gt_polygon = ast.literal_eval(ground_truth)
        if not isinstance(gt_polygon, list) or len(gt_polygon) < 3:
            return False, False, parsed_answer, is_parsable

        # Fast paths: first "(x, y)" tuple or "[x, y]" pair in the text.
        # This is more robust than trying to parse the entire response.
        for pattern in (_TUPLE_POINT_RE, _LIST_POINT_RE):
            point = _first_coordinate(generated_answer, pattern)
            if point is not None:
                parsed_answer = point
                is_parsable = True
                correct = point_in_polygon(point[0], point[1], gt_polygon)
                return correct, False, parsed_answer, is_parsable

        # Fallback: parse the first bracketed span as a Python literal,
        # e.g. "[(0.1, 0.2), (0.3, 0.4)]", possibly spanning multiple lines.
        try:
            match = re.search(r'\[(.*?)\]', generated_answer, re.DOTALL)
            if match is None:
                return False, False, parsed_answer, is_parsable
            # Insert missing spaces after commas and drop a trailing comma
            # left by a truncated generation, to help ast.literal_eval.
            list_content = re.sub(r',(\S)', r', \1', match.group(1)).strip()
            if list_content.endswith(','):
                list_content = list_content[:-1]
            try:
                gen_val = ast.literal_eval('[' + list_content + ']')
            except (SyntaxError, ValueError):
                # Malformed literal: salvage the first tuple inside it.
                point = _first_coordinate(list_content, _TUPLE_POINT_RE)
                if point is None:
                    return False, False, parsed_answer, is_parsable
                parsed_answer = point
                is_parsable = True
                correct = point_in_polygon(point[0], point[1], gt_polygon)
                return correct, False, parsed_answer, is_parsable

            gen_point = _point_from_literal(gen_val)
            if gen_point is None:
                return False, False, parsed_answer, is_parsable
            x, y = float(gen_point[0]), float(gen_point[1])
            parsed_answer = (x, y)
            is_parsable = True
            correct = point_in_polygon(x, y, gt_polygon)
            return correct, False, parsed_answer, is_parsable
        except Exception:
            # Any other failure in the fallback path: answer is not parsable.
            return False, False, parsed_answer, is_parsable
    except Exception as e:
        print(f"Error evaluating answer: {e}")
        return False, False, parsed_answer, is_parsable
def eval_robospatial_home(json_data, model_name, model_kwargs, data_dir, run_model_fn):
    """
    Evaluate RoboSpatial-Home data by running the model on each example.

    Args:
        json_data: List of data entries to evaluate
        model_name: Name of the model being evaluated
        model_kwargs: Model-specific arguments (tokenizer, model object, etc.)
        data_dir: Root directory containing dataset files and images
        run_model_fn: Function to run the model on a single example

    Returns:
        Dictionary containing evaluation statistics and results
    """
    results = []
    num_correct = 0
    # Count only entries that can actually be evaluated, so the accuracy
    # denominator matches eval_pregenerated_results. Previously the
    # denominator was len(json_data), which included ill-formed entries
    # that are skipped below and therefore silently deflated accuracy;
    # those entries are tracked in illformed_questions instead.
    num_total = 0
    illformed_questions = 0
    illformed_responses = 0
    # Per-category {"num_correct", "num_total"} statistics
    category_stats = {}
    for entry in tqdm(json_data, desc="Evaluating RoboSpatial-Home"):
        # Extract question and ground-truth answer directly from the entry
        question = entry.get("question", "")
        ground_truth = entry.get("answer", "")
        if not question or not ground_truth:
            illformed_questions += 1
            continue
        num_total += 1
        category = entry.get("category", "unknown")
        if category not in category_stats:
            category_stats[category] = {"num_correct": 0, "num_total": 0}
        category_stats[category]["num_total"] += 1
        # Build absolute image path using the entry's relative "img" field
        image_rel_path = entry.get("img", "")
        image_path = os.path.join(data_dir, image_rel_path)
        # Run the model
        generated_answer = run_model_fn(question, image_path, model_name, model_kwargs)
        # Evaluate the answer
        correct, is_binary, parsed_answer, is_parsable = evaluate_answer(ground_truth, generated_answer)
        # Track any answer that couldn't be parsed into a gradeable form
        if not is_parsable:
            illformed_responses += 1
        if correct:
            num_correct += 1
            category_stats[category]["num_correct"] += 1
        results.append({
            "question": question,
            "expected_answer": ground_truth,
            "generated_answer": generated_answer,
            "parsed_answer": str(parsed_answer) if parsed_answer is not None else None,
            "correct": correct,
            "is_parsable": is_parsable,
            "category": category,
            "image": image_path
        })
    accuracy = 100.0 * num_correct / num_total if num_total > 0 else 0.0
    return {
        "accuracy": accuracy,
        "num_correct": num_correct,
        "num_total": num_total,
        "illformed_questions": illformed_questions,
        "illformed_responses": illformed_responses,
        "category_stats": category_stats,
        "results": results
    }
def eval_pregenerated_results(gt_data, results_data, data_dir):
    """
    Evaluate pre-generated model responses against benchmark ground truth.

    Args:
        gt_data: List of ground truth data (from the benchmark)
        results_data: List of pre-generated model responses
        data_dir: Root directory containing dataset files and images

    Returns:
        Dictionary of evaluation statistics
    """
    # Index responses by (question, image) once, so each ground-truth entry
    # is matched in O(1) instead of scanning the whole results list.
    response_lookup = {
        (r.get("question", ""), r.get("img", "")): r
        for r in results_data
        if r.get("question", "") and r.get("img", "")
    }

    results = []
    num_correct = 0
    num_total = 0          # only entries that could actually be evaluated
    illformed_questions = 0
    illformed_responses = 0
    unmatched_entries = 0  # ground-truth entries with no matching response
    # Per-category {"num_correct", "num_total"} statistics
    category_stats = {}

    for gt_entry in tqdm(gt_data, desc="Evaluating Pre-generated Results"):
        question = gt_entry.get("question", "")
        ground_truth = gt_entry.get("answer", "")
        image_rel_path = gt_entry.get("img", "")

        if not question or not ground_truth:
            illformed_questions += 1
            continue

        # Register the category even when the entry ends up unmatched, so
        # category_stats lists every category seen in the ground truth.
        category = gt_entry.get("category", "unknown")
        stats = category_stats.setdefault(category, {"num_correct": 0, "num_total": 0})

        matched_result = response_lookup.get((question, image_rel_path))
        if matched_result is None:
            # Counted as unmatched rather than ill-formed.
            unmatched_entries += 1
            continue

        # Only matched entries count toward the totals.
        num_total += 1
        stats["num_total"] += 1

        generated_answer = matched_result.get("answer", "")
        if not generated_answer:
            illformed_responses += 1
            continue

        image_path = os.path.join(data_dir, image_rel_path)
        correct, is_binary, parsed_answer, is_parsable = evaluate_answer(ground_truth, generated_answer)

        # Track any answer that couldn't be parsed into a gradeable form.
        if not is_parsable:
            illformed_responses += 1
        if correct:
            num_correct += 1
            stats["num_correct"] += 1

        results.append({
            "question": question,
            "expected_answer": ground_truth,
            "generated_answer": generated_answer,
            "parsed_answer": str(parsed_answer) if parsed_answer is not None else None,
            "correct": correct,
            "is_parsable": is_parsable,
            "category": category,
            "image": image_path,
        })

    accuracy = 100.0 * num_correct / num_total if num_total > 0 else 0.0
    return {
        "accuracy": accuracy,
        "num_correct": num_correct,
        "num_total": num_total,
        "illformed_questions": illformed_questions,
        "illformed_responses": illformed_responses,
        "unmatched_entries": unmatched_entries,
        "category_stats": category_stats,
        "results": results,
    }