-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtorch_datasets.py
More file actions
41 lines (35 loc) · 1.47 KB
/
torch_datasets.py
File metadata and controls
41 lines (35 loc) · 1.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from torch.utils.data import Dataset
import torch
class MRINEDataset(Dataset):
    '''
    MRINE Dataset.

    Each item is a list ``[s, y, m_s, m_y, target]`` holding the ``idx``-th
    entry (along the first, sequence dimension) of every stored tensor.

    Args:
        s: torch.Tensor, spiking activity, or discrete observations. (num_seq, num_steps, n_s)
        y: torch.Tensor, LFP, or continuous signals. (num_seq, num_steps, n_y)
        m_s: torch.Tensor, mask tensor for s denoting whether s is available for a time-step or not. (num_seq, num_steps)
        m_y: torch.Tensor, mask tensor for y denoting whether y is available for a time-step or not. (num_seq, num_steps)
        target: torch.Tensor, target variable to decode from inferred latent factors. (num_seq, num_steps)
            If omitted, ``torch.ones_like(y)`` is stored as a placeholder.

    If only one of ``s`` / ``y`` is given, the missing modality and its mask are
    replaced by the provided modality and its mask (the other mask argument is
    ignored in that case), so both slots always hold a tensor.

    Raises:
        ValueError: if both ``s`` and ``y`` are None.
    '''

    def __init__(self, s=None, y=None, m_s=None, m_y=None, target=None):
        if s is None and y is None:
            raise ValueError('MRINEDataset requires at least one of s or y.')
        # Single-modality input: mirror the available modality (and its mask)
        # into the missing slot so every item still has five entries.
        if s is None:
            s, m_s = y, m_y
        elif y is None:
            y, m_y = s, m_s
        self.s = s
        self.y = y
        # NOTE(review): masks are not defaulted; if a mask is None, indexing it
        # in __getitem__ will raise. Confirm callers always supply masks.
        self.m_s = m_s
        self.m_y = m_y
        # NOTE(review): this placeholder has y's full shape
        # (num_seq, num_steps, n_y), not the (num_seq, num_steps) documented for
        # target. Kept as-is for compatibility; confirm downstream expectations.
        self.target = torch.ones_like(y) if target is None else target

    def __len__(self):
        # Number of sequences, taken from the first dimension of s.
        return self.s.shape[0]

    def __getitem__(self, idx):
        # Index the leading (sequence) dimension of every field. ``data[idx]``
        # equals the old ``data[idx, :]`` / ``data[idx, :, :]`` for 2-D/3-D
        # tensors. It also handles other ranks (e.g. a 1-D per-sequence target)
        # instead of silently dropping them, which used to shift the positions
        # of the remaining fields in the returned list.
        return [data[idx] for data in (self.s, self.y, self.m_s, self.m_y, self.target)]