-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathGAN_DataLoader2.py
More file actions
92 lines (73 loc) · 2.66 KB
/
GAN_DataLoader2.py
File metadata and controls
92 lines (73 loc) · 2.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Want to load:
# - Target image
# - Source image embedding
# - Target image embedding
import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
import torchvision
import os
from PIL import Image
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchvision.datasets import ImageFolder
# IMAGE_PATH = ""
# transform = None
# dataset = ImageFolder(IMAGE_PATH, transform)
class GAN_Dataset(data.Dataset):
    """Flat-directory image dataset yielding one transformed RGB image per item.

    Every regular file directly inside ``dir`` is treated as an image;
    subdirectories are skipped. File names are sorted so that item ``i`` maps
    to the same file on every run and platform, which matters when two
    instances are paired index-by-index.

    Args:
        dir: Directory containing the image files.
        transform: Callable applied to each RGB ``PIL.Image``; its return
            value is what ``__getitem__`` yields.
    """

    def __init__(self, dir, transform):
        super(GAN_Dataset, self).__init__()
        self.dir = dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort once so indices
        # are deterministic, and drop subdirectories Image.open cannot read.
        self.total_imgs = sorted(
            name for name in os.listdir(dir)
            if os.path.isfile(os.path.join(dir, name))
        )
        # Full paths, index-aligned with total_imgs.
        self.img_paths = [os.path.join(dir, name) for name in self.total_imgs]

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        # Image.open keeps the file handle open lazily; convert() loads the
        # pixels into a new image, and the context manager then closes the file.
        with Image.open(self.img_paths[idx]) as img:
            image = img.convert('RGB')
        return self.transform(image)
class ConcatDataset(torch.utils.data.Dataset):
    """Zip several datasets together index-by-index.

    Item ``i`` is a tuple containing item ``i`` of each wrapped dataset, in
    the order the datasets were passed. The reported length is that of the
    shortest wrapped dataset. (Note: ``torch.utils.data.ConcatDataset``, by
    contrast, chains datasets end to end.)
    """

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        items = [dataset[i] for dataset in self.datasets]
        return tuple(items)

    def __len__(self):
        lengths = map(len, self.datasets)
        return min(lengths)
def dataloader(root_source=r"\clean_dataset\train_data",
               root_target=r"\clean_dataset\train_data",
               image_size=224,
               num_channels=3,
               batch_size=4,
               num_workers=6,
               shuffle=True):
    """Build a DataLoader over index-paired source/target images.

    Item ``i`` of ``root_source`` is paired with item ``i`` of
    ``root_target``; each image is resized to ``image_size`` x ``image_size``
    and converted to a tensor. Each batch therefore holds a source part and a
    target part.

    Args:
        root_source: Directory of source images. The default is a raw string
            so the backslashes are not read as escape sequences.
        root_target: Directory of target images.
        image_size: Side length, in pixels, of the square resize.
        num_channels: Currently unused; kept for backward compatibility.
        batch_size: Number of pairs per batch.
        num_workers: Number of DataLoader worker processes.
        shuffle: Whether to shuffle the pairs (pairs stay matched).

    Returns:
        A ``torch.utils.data.DataLoader`` over the paired datasets.
    """
    transform = transforms.Compose([
        # Resize the PIL image before ToTensor: PIL resizing is always
        # antialiased, whereas tensor resizing depends on the torchvision
        # version (and very old versions cannot resize tensors at all).
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    paired = ConcatDataset(
        GAN_Dataset(dir=root_source, transform=transform),
        GAN_Dataset(dir=root_target, transform=transform),
    )
    return data.DataLoader(
        paired,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )