-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathestimate.py
More file actions
107 lines (88 loc) · 3.78 KB
/
estimate.py
File metadata and controls
107 lines (88 loc) · 3.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Module for estimating parcellations and functional profiles
Author: Bassel Arafat
"""
import torch as pt
import numpy as np
import OptimalBattery.evaluate as ev
import OptimalBattery.util as ut
######################### Parcellation U Estimation #########################
def estimate_Us(Y, V, method='cos_angle', alpha=1e-3, hard=False):
    """
    Estimate U_hat by projecting the data onto the functional profiles.

    Args:
        Y (torch.Tensor): fMRI data of shape (n_subjects, n_tasks, n_voxels).
            A 2-D tensor of shape (n_tasks, n_voxels) is treated as one subject.
        V (torch.Tensor): Functional profile of shape (n_tasks, n_parcels)
        method (str): Choice of {'cos_angle', 'ols', 'ridge'}
            'cos_angle': plain projection V^T @ Y. No normalisation is done
                         here, so this is a cosine only if the columns of V
                         and Y are already unit length.
            'ols':       (V^T V)^(-1) V^T @ Y
            'ridge':     (V^T V + alpha*I)^(-1) V^T @ Y
        alpha (float): Regularization for ridge (ignored unless method == 'ridge')
        hard (bool): If True, returns one-hot assignment for each voxel to one parcel
    Returns:
        U_hats (torch.Tensor):
            If hard=False, shape = (n_subjects, n_parcels, n_voxels) with continuous weights
            If hard=True, shape = (n_subjects, n_parcels, n_voxels) with 0/1 assignments
    Raises:
        ValueError: If method is not one of the supported names.
    """
    # A single subject may be passed without the leading subject dimension
    if Y.dim() == 2:
        Y = Y.unsqueeze(0)

    # Reject unknown methods before doing any linear algebra
    if method not in ('cos_angle', 'ols', 'ridge'):
        raise ValueError(
            f"Invalid method '{method}': expected 'cos_angle', 'ols' or 'ridge'"
        )

    # 1) Compute weights depending on method
    # V.T broadcasts over the subject dimension: (n_subjects, n_parcels, n_voxels)
    U_hats = V.T @ Y
    if method != 'cos_angle':
        gram = V.T @ V
        if method == 'ridge':
            # Build the identity in the dtype of the Gram matrix so alpha is
            # not rounded to float32 when working in float64
            gram = gram + alpha * pt.eye(
                gram.shape[0], dtype=gram.dtype, device=gram.device
            )
        U_hats = pt.linalg.inv(gram) @ U_hats

    # 2) Return continuous or hard assignments
    if hard:
        # No tie-breaking offset is added to the scores: an offset that varies
        # only along the voxel axis is the same for every parcel of a voxel, so
        # it cannot separate tied parcels, and its absolute size can swamp
        # small scores and change the winner. argmax alone is deterministic
        # (the first maximal parcel is returned on exact ties).
        winners = pt.argmax(U_hats, dim=1, keepdim=True)  # (n_subjects, 1, n_voxels)
        U_hats = pt.zeros_like(U_hats).scatter_(1, winners, 1)
    return U_hats
######################### V estimation #####################################
def estimate_Vs(data, parcellation, ROI_mask=None):
    """
    Build functional profiles by averaging the group-mean data over the voxels of each parcel.

    When an ROI mask is supplied, only voxels that lie in both the parcel and
    the ROI contribute, and parcels with no voxel inside the ROI are left out.

    Parameters:
        data (torch.Tensor): fMRI data of shape (n_subjects, n_tasks, n_voxels).
        parcellation (torch.Tensor): Parcellation indices of shape (n_voxels,).
        ROI_mask (torch.Tensor): Mask of shape (n_voxels,); entries > 0 mark the region of interest. (optional)
    Returns:
        Vs (torch.Tensor): One column per retained parcel, in the order given by
            pt.unique(parcellation), shape (n_tasks, n_selected_parcels).
    """
    # Collapse the subject dimension first: (n_tasks, n_voxels)
    group_mean = data.mean(dim=0)

    # The ROI condition does not depend on the parcel, so evaluate it once
    inside_roi = None if ROI_mask is None else ROI_mask > 0

    profiles = []
    for label in pt.unique(parcellation):
        member = parcellation == label
        if inside_roi is not None:
            # Keep only the part of the parcel that falls inside the ROI
            member = member & inside_roi
        voxel_idx = pt.nonzero(member, as_tuple=True)[0]
        # A parcel with no voxels left contributes no column
        if voxel_idx.numel() == 0:
            continue
        profiles.append(group_mean[:, voxel_idx].mean(dim=1))

    return pt.stack(profiles, dim=1)
if __name__ == "__main__":
    # Running this module as a script does nothing; the lines below are a
    # commented-out usage example only.
    # NOTE: estimate_Us accepts 'cos_angle', 'ols' or 'ridge'; passing
    # method='correlation' raises ValueError.
    # Vs = pt.rand(29, 5)
    # data = pt.rand(24, 29, 100)
    # Us = estimate_Us(data, Vs, method='cos_angle', hard=True)
    pass