-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathR2.py
More file actions
75 lines (63 loc) · 2.58 KB
/
R2.py
File metadata and controls
75 lines (63 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from sparc.post_analysis import HDF5AnalysisResultsDataset
from sklearn.metrics import r2_score
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
# Batch size handed to HDF5AnalysisResultsDataset (second positional argument)
# in main(). Judging by the name it must equal the batch size that was used when
# the analysis cache was generated — NOTE(review): confirm against that script.
batch_size_used_for_generation: int = 256
def compute_r2(analysis_results):
    """Build the cross-stream R² matrix for one analysis cache.

    Every (target, source) pair drawn from ``analysis_results.streams`` is
    scored with ``sklearn.metrics.r2_score``, always using the target stream's
    'raw' features as ground truth. On the diagonal (source == target) the
    prediction is the stream's own 'recon' features; off the diagonal it is
    the cross-reconstruction entry named ``f'{target}_from_{source}'``.

    Args:
        analysis_results: dataset object exposing ``streams``,
            ``get_all_features_for_stream(stream, kind)`` and
            ``get_all_cross_reconstruction_features(name)``.

    Returns:
        pandas.DataFrame: square table of R² scores. Rows ('Original Target')
        are the stream being reconstructed; columns
        ('Reconstructed From (Source)') are the stream it was reconstructed from.
    """
    print("Loading data:")
    stream_names = analysis_results.streams

    # Load every stream's ground-truth and self-reconstruction features up
    # front: all 'raw' tensors first, then all 'recon' tensors.
    ground_truth = {}
    for name in stream_names:
        ground_truth[name] = analysis_results.get_all_features_for_stream(name, 'raw')
    self_recon = {}
    for name in stream_names:
        self_recon[name] = analysis_results.get_all_features_for_stream(name, 'recon')

    # Score pair by pair (a plain nested loop, not vectorized);
    # row index = target stream, column index = source stream.
    n_streams = len(stream_names)
    scores = np.empty((n_streams, n_streams))
    for row, target in enumerate(tqdm(stream_names)):
        truth = ground_truth[target]
        for col, source in enumerate(stream_names):
            if source != target:
                prediction = analysis_results.get_all_cross_reconstruction_features(
                    f'{target}_from_{source}'
                )
            else:
                prediction = self_recon[target]
            scores[row, col] = r2_score(truth, prediction)

    r2_df = pd.DataFrame(scores, index=stream_names, columns=stream_names)
    r2_df.index.name = 'Original Target'
    r2_df.columns.name = 'Reconstructed From (Source)'
    return r2_df
def main(scenarios=None, variants=None, results_dir='results', batch_size=None):
    """Compute, print and save cross-stream R² tables for each run folder.

    For every ``(scenario, variant)`` pair this looks for
    ``<results_dir>/<scenario>_<variant>/analysis_cache_val.h5``. If the file
    exists, it is opened with ``HDF5AnalysisResultsDataset`` and the R² table
    from ``compute_r2`` is printed. The table is then written to
    ``<scenario>_<variant>_r2_scores.csv`` in the same folder. Otherwise a
    "File not found" line is printed and the pair is skipped.

    All parameters are optional; ``main()`` with no arguments performs exactly
    the original hard-coded run.

    Args:
        scenarios: iterable of scenario names. Defaults to ``('sparc_open_all',)``.
        variants: iterable of variant suffixes. Defaults to
            ``('global_cross', 'local_cross')``.
        results_dir: directory holding the per-run result folders.
        batch_size: second argument passed to ``HDF5AnalysisResultsDataset``;
            ``None`` means use the module-level
            ``batch_size_used_for_generation``.
    """
    if scenarios is None:
        # Alternative scenario names (disabled by default):
        # 'sparc_open_siglip', 'sparc_open_clip', 'sparc_open_image',
        # 'sparc_open_image_only', 'sparc_open_text', 'sparc_open_text_only'.
        scenarios = ('sparc_open_all',)
    if variants is None:
        variants = ('global_cross', 'local_cross')

    for scenario in scenarios:
        print(f"\n{'='*60}")
        print(f"Scenario: {scenario}")
        print('='*60)
        for variant in variants:
            # Forward-slash paths kept as f-strings so printed paths are
            # identical on every platform.
            folder_path = f'{results_dir}/{scenario}_{variant}'
            file_path = f'{folder_path}/analysis_cache_val.h5'
            if not os.path.exists(file_path):
                print(f"\nFile not found: {file_path}")
                continue

            print(f"\n{variant.replace('_', ' ').title()}:")
            print('-'*40)
            # Resolve the module constant lazily so it is only needed when a
            # cache file is actually opened.
            effective_batch_size = (
                batch_size_used_for_generation if batch_size is None else batch_size
            )
            analysis_results = HDF5AnalysisResultsDataset(
                file_path, effective_batch_size
            )
            r2_df = compute_r2(analysis_results)
            print(r2_df.to_string())
            # Save the R² table next to the cache it was computed from.
            csv_output_path = f'{folder_path}/{scenario}_{variant}_r2_scores.csv'
            r2_df.to_csv(csv_output_path)
            print(f"R² scores saved to: {csv_output_path}")


if __name__ == "__main__":
    main()