-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils_data.py
More file actions
72 lines (59 loc) · 2.43 KB
/
utils_data.py
File metadata and controls
72 lines (59 loc) · 2.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
import torch
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset
import glob
from torchvision import transforms
from torchvision.transforms import functional as TF
import random
def load_image(filename, is_mask=False):
    """Open an image file and return a fully loaded PIL image.

    Args:
        filename: Path (str or path-like) of the image file to read.
        is_mask: If True, convert to single-channel 8-bit grayscale (mode
            'L'); otherwise convert to 3-channel 'RGB'.

    Returns:
        A new ``PIL.Image.Image`` in mode 'L' or 'RGB'.
    """
    mode = 'L' if is_mask else 'RGB'
    # ``Image.open`` is lazy and keeps the underlying file open. ``convert``
    # returns a new, fully loaded image (a copy even when the mode already
    # matches), so closing the source via ``with`` is safe and guarantees the
    # file handle is released, including for multi-frame files.
    with Image.open(filename) as source:
        return source.convert(mode)
class Dataset_Seg(Dataset):
    """Image/mask segmentation dataset read from two folders of PNG files.

    Expects ``<data_path>/images/*.png`` and ``<data_path>/masks/*.png``.
    Images and masks are paired purely by position after sorting each
    folder's file paths. Both folders must therefore contain the same number
    of files, with names that sort into the same order.

    The sorted file list is split deterministically (no shuffling) at the
    cumulative fractions in ``split_bounds``. With the default ``(0.8, 0.9)``:
    - the first 80 % of files form ``'train'``;
    - the next 10 % form ``'val'``;
    - the remainder is used for any other ``partition`` value (e.g. ``'test'``).

    Args:
        data_path: Root directory (str or path-like) containing the
            ``images`` and ``masks`` subfolders.
        partition: ``'train'`` or ``'val'``; any other value selects the
            final (test) slice.
        transform: Stored as ``self.transform``; not applied by this class.
        augment: If True, ``__getitem__`` applies random flips/rotation
            jointly to each image and its mask.
        image_size: ``(width, height)`` that every image and mask is resized
            to (PIL ``resize`` convention). Defaults to ``(256, 256)``.
        split_bounds: Cumulative fractions ``(train_end, val_end)`` of the
            sorted file list at which the train and val slices end.

    Raises:
        ValueError: If the ``images`` and ``masks`` folders contain a
            different number of PNG files.
    """

    def __init__(self, data_path, partition, transform=None, augment=False,
                 image_size=(256, 256), split_bounds=(0.8, 0.9)):
        self.data_path = Path(data_path)
        self.partition = partition
        self.augment = augment
        # NOTE(review): ``transform`` is never used by ``__getitem__``; confirm
        # what callers expect it to be applied to before wiring it in.
        self.transform = transform
        self.image_size = tuple(image_size)

        image_files = self._list_pngs(self.data_path / 'images')
        mask_files = self._list_pngs(self.data_path / 'masks')
        # Pairing is positional, so a count mismatch would silently pair
        # images with the wrong masks; fail loudly instead.
        if len(image_files) != len(mask_files):
            raise ValueError(
                f"Found {len(image_files)} images but {len(mask_files)} masks "
                f"under {self.data_path}; image/mask pairs cannot be aligned."
            )

        total = len(image_files)
        train_end = int(split_bounds[0] * total)
        val_end = int(split_bounds[1] * total)
        if partition == 'train':
            selected = slice(None, train_end)
        elif partition == 'val':
            selected = slice(train_end, val_end)
        else:
            selected = slice(val_end, None)
        self.images_files = image_files[selected]
        self.masks_files = mask_files[selected]

    @staticmethod
    def _list_pngs(folder):
        """Return the sorted ``str`` paths of ``*.png`` entries in ``folder``."""
        # glob.escape stops '[', '*' or '?' in the dataset path itself from
        # being treated as wildcards; only the '*.png' part is a pattern.
        return sorted(glob.glob(f"{glob.escape(str(folder))}/*.png"))

    def __len__(self):
        """Return the number of image/mask pairs in this partition."""
        return len(self.images_files)

    def __getitem__(self, idx):
        """Load, resize, optionally augment and tensorize pair ``idx``.

        Returns:
            ``(image_tensor, mask_tensor)``, both float32 and divided by 255:
            the RGB image as ``(3, H, W)`` and the grayscale mask as
            ``(1, H, W)``, where ``(W, H) == self.image_size``.
        """
        image = load_image(self.images_files[idx])
        mask = load_image(self.masks_files[idx], is_mask=True)

        # Bicubic for the photo; nearest for the mask so resizing only reuses
        # existing mask values instead of blending new ones at boundaries.
        image = image.resize(self.image_size, resample=Image.BICUBIC)
        mask = mask.resize(self.image_size, resample=Image.NEAREST)

        if self.augment:
            image, mask = self._joint_augment(image, mask)

        # float32 (torch's default dtype) rather than float64, so the tensors
        # can be fed to a float32 model/loss without an extra cast.
        image_array = np.asarray(image, dtype=np.float32) / 255
        mask_array = np.asarray(mask, dtype=np.float32) / 255
        image_tensor = torch.from_numpy(image_array).permute(2, 0, 1)
        mask_tensor = torch.from_numpy(mask_array).unsqueeze(0)
        return image_tensor, mask_tensor

    @staticmethod
    def _joint_augment(image, mask):
        """Apply one random set of flips/rotation identically to image and mask."""
        # Each transform fires independently when random.random() > 0.5; the
        # same decision (and angle) is reused for the mask to keep alignment.
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)
        if random.random() > 0.5:
            angle = random.uniform(-15, 15)
            # torchvision's default (nearest) interpolation keeps mask values
            # discrete after rotation.
            image = TF.rotate(image, angle)
            mask = TF.rotate(mask, angle)
        return image, mask