-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDataset_class.py
More file actions
56 lines (42 loc) · 1.8 KB
/
Dataset_class.py
File metadata and controls
56 lines (42 loc) · 1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import itertools
import os
import time

import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader, ConcatDataset
# Wall-clock start timestamp; the script's final lines subtract it from the
# end timestamp to print the elapsed time in seconds.
st = time.time()
class BuildingDataset(data.Dataset):
    """Dataset yielding, per building, its (satellite, drone) image pairs.

    Buildings are the sorted entries of ``<root_dir>/train/drone``. Item
    ``i`` is a list with one ``(satellite_img, drone_img)`` tuple for every
    file in that building's drone directory; the same satellite image (the
    first file, in sorted order, of ``<root_dir>/train/satellite/<building>``)
    is shared by all tuples of the item.
    """

    def __init__(self, root_dir, transform=None):
        """Index the buildings found under ``root_dir``.

        Args:
            root_dir: Dataset root containing ``train/drone`` and
                ``train/satellite``, each with one sub-directory per building.
            transform: Optional callable applied to every loaded RGB PIL
                image. When ``None`` the PIL images are returned unchanged.

        Raises:
            OSError: If ``<root_dir>/train/drone`` cannot be listed.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.buildings = sorted(os.listdir(os.path.join(self.root_dir, 'train', 'drone')))

    def __len__(self):
        """Return the number of buildings (not the number of image pairs)."""
        return len(self.buildings)

    def _load(self, path):
        """Open ``path`` as an RGB image and apply ``self.transform`` if set."""
        # ``convert`` returns a fully loaded copy, so the underlying file can
        # be closed as soon as the ``with`` block exits.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        if self.transform is not None:
            return self.transform(rgb)
        return rgb

    def __getitem__(self, index):
        """Return the list of (satellite, drone) pairs for building ``index``.

        Args:
            index: Position in ``self.buildings``.

        Returns:
            A list of ``(satellite_img, drone_img)`` tuples, one per file in
            the building's drone directory, ordered by file name.

        Raises:
            IndexError: If ``index`` is out of range. Plain ``for`` iteration
                over the dataset relies on this to stop.
            FileNotFoundError: If the building's satellite directory is
                missing or contains no file.
        """
        building_name = self.buildings[index]
        drone_dir = os.path.join(self.root_dir, 'train', 'drone', building_name)
        satellite_dir = os.path.join(self.root_dir, 'train', 'satellite', building_name)

        # os.listdir returns entries in arbitrary order; sort so the chosen
        # satellite image and the drone ordering are reproducible.
        satellite_files = sorted(os.listdir(satellite_dir))
        if not satellite_files:
            # Do not let an IndexError escape here: inside __getitem__ it
            # would be read as "end of dataset" and silently stop iteration.
            raise FileNotFoundError(f"no satellite image found in {satellite_dir!r}")
        satellite_img = self._load(os.path.join(satellite_dir, satellite_files[0]))

        # Build the pairs whether or not a transform is configured.
        return [
            (satellite_img, self._load(os.path.join(drone_dir, drone_img_name)))
            for drone_img_name in sorted(os.listdir(drone_dir))
        ]
# Preprocessing applied to every image: resize to 256x256, convert to a
# tensor, then normalize the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = BuildingDataset(root_dir='University-Release', transform=transform)

# Each dataset item is a list of (satellite, drone) pairs for one building;
# flatten them into one list of pairs across all buildings. `itertools` is
# imported with the other modules at the top of the file.
result_list = list(itertools.chain.from_iterable(dataset))
print(len(result_list))

# Report the seconds elapsed since `st` was recorded near the top of the file.
end = time.time()
print(end - st)