-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcollisionCheckDataset.py
More file actions
39 lines (33 loc) · 1.09 KB
/
collisionCheckDataset.py
File metadata and controls
39 lines (33 loc) · 1.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
from torch.utils.data import Dataset
from utils import getData
import numpy as np
from trajAugmentations import TrajAugs
from finetuneSWAV import set_params
class CollisionCheckDataset(Dataset):
    """Expose the data returned by ``getData`` as a single indexable dataset.

    ``getData(args)`` may return either one dataset or a list of datasets. In the
    list case the sub-datasets are concatenated end to end, so global index ``i``
    walks through the first sub-dataset, then the second, and so on.
    """

    def __init__(self, args):
        """Load the underlying data via ``getData(args)`` and record its size.

        Args:
            args: configuration object forwarded unchanged to ``getData``
                (presumably the object returned by ``set_params`` — confirm).
        """
        self.data = getData(args)
        if isinstance(self.data, list):
            # Per-sub-dataset lengths; __getitem__ uses them to map a global
            # index onto (sub-dataset, local index).
            self.partitions = [len(d) for d in self.data]
            self.length = sum(self.partitions)
        else:
            self.length = len(self.data)
        # Not used by this class yet; presumably intended for an augmentation
        # step in __getitem__ (TODO confirm).
        self.augs = TrajAugs()

    def __len__(self):
        """Return the total number of samples across all underlying data."""
        return self.length

    def __getitem__(self, item):
        """Return the sample at global index ``item``.

        Args:
            item: integer index. For a list of sub-datasets, negative values
                count from the end of the concatenation.

        Returns:
            Whatever the underlying dataset's ``__getitem__`` returns, unchanged.

        Raises:
            IndexError: if ``item`` is out of range for a list of sub-datasets.
        """
        if not isinstance(self.data, list):
            # Single dataset: delegate indexing (including bounds) to it.
            return self.data[item]

        n = self.length
        idx = item + n if item < 0 else item
        if not 0 <= idx < n:
            raise IndexError(f"index {item} is out of range for a dataset of length {n}")

        # Cumulative end offsets, e.g. partitions [3, 4] -> [3, 7].
        ends = np.cumsum(self.partitions)
        # side="right" selects the first sub-dataset whose end offset is strictly
        # greater than idx, which also skips empty sub-datasets.
        part = int(np.searchsorted(ends, idx, side="right"))
        start = int(ends[part - 1]) if part > 0 else 0
        batch = self.data[part][idx - start]
        # TODO: judging by the original unpacking, a 4-element batch holds
        # (peopleIDs, locs, targ_locs, frame). No collision check or augmentation
        # uses them yet, so the batch is returned unchanged.
        return batch
if __name__ == "__main__":
    # Manual smoke run: build the dataset from set_params() and print one sample.
    # Guarded so that importing this module does not construct the dataset.
    args = set_params()
    d = CollisionCheckDataset(args)
    print(d[2000])