-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdatasets.py
More file actions
108 lines (94 loc) · 3.82 KB
/
datasets.py
File metadata and controls
108 lines (94 loc) · 3.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from pathlib import Path
from PIL import Image
import numpy as np
import os
import pickle
import PIL
import torch
from torch.utils.data import Dataset
from config import Config
# Module-level configuration object built from the project's ``Config`` class.
# NOTE(review): not referenced by the definitions visible in this part of the
# module -- confirm whether it is still needed before removing it.
config = Config()
def get_slide_path(train_path, parent_path):
    """Collect every patch stored under the given slide folders.

    The layout walked is ``parent_path/<slide>/<class>/<patch file>``. The
    patch file name, without its extension, must be a sequence of integers
    joined by underscores (for example ``12_34.png``).

    Args:
        train_path: iterable of WSI (slide) folder names.
        parent_path: directory that contains the slide folders, given as a
            ``str`` or a ``pathlib.Path``.

    Returns:
        Four parallel lists ``(patch_path, patch_label, patch_position,
        parent_slide)``:
            patch_path: ``pathlib.Path`` of each patch file.
            patch_label: name of the class folder that holds the patch.
            patch_position: list of integers parsed from the patch file name.
            parent_slide: the slide name (entry of ``train_path``) the patch
                belongs to.

    Raises:
        ValueError: if a patch file name is not made of underscore-separated
            integers.
    """
    root = Path(parent_path)
    patch_path = []
    patch_label = []
    patch_position = []
    parent_slide = []
    for item in train_path:
        slide_folder = root / item
        # Sorted so the result does not depend on the OS directory order.
        for class_folder in sorted(slide_folder.iterdir()):
            if not class_folder.is_dir():
                # Stray files next to the class folders are not classes.
                continue
            for patch in sorted(class_folder.iterdir()):
                patch_path.append(patch)
                patch_label.append(class_folder.name)
                # ``stem`` drops the extension whatever its length, instead of
                # assuming exactly three characters plus the dot.
                patch_position.append([int(i) for i in patch.stem.split('_')])
                parent_slide.append(item)
    return patch_path, patch_label, patch_position, parent_slide
def pil_loader(path):
    """Read the image file at ``path`` and return it converted to RGB.

    Args:
        path: location of the image file to read.

    Returns:
        The result of calling ``convert('RGB')`` on the opened PIL image.
    """
    # The conversion runs while the file is still open, so the returned image
    # no longer needs the handle once the ``with`` block exits.
    # TODO: Image.open(path) could be used directly instead of open().
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
class SlideData(Dataset):
    """Dataset that serves patch-level images grouped by whole-slide image (WSI).

    Attributes set by ``__init__``:
        transforms: optional callable applied to every loaded image.
        path: root directory handed to ``get_slide_path``.
        overall_info: the unpickled ``[id2slide, id2label, id2split]`` object.
        slide2id: inverse of ``id2slide`` (slide name -> slide id).
        classes: sorted list of the distinct labels in ``id2label``.
        class_to_idx: mapping from each label to its index in ``classes``.
        plpp: result of ``get_slide_path`` -- parallel lists of patch paths,
            labels, positions and parent slides.
        current_id: slide ids whose patches are currently loaded.
    """

    def __init__(self, path, overall_info, transforms=None, train='train', exclude=None):
        """Load the slide metadata and index the patches of one split.

        Args:
            path: root directory that contains one folder per slide.
            overall_info: path of a pickle file holding
                ``[id2slide, id2label, id2split]``.
            transforms: optional callable applied to each image in
                ``__getitem__``.
            train: which split to index: ``'train'``, ``'val'`` or
                ``'trainplusval'``; any other value selects the ``'test'``
                split.
            exclude: when not ``None``, slides whose label equals this value
                are removed from all three mappings.
        """
        self.transforms = transforms
        self.path = path
        # SECURITY: pickle.load can execute arbitrary code; only point this at
        # metadata files from a trusted source.
        # The ``with`` block guarantees the file handle is closed.
        with open(overall_info, 'rb') as handle:
            overall_info = pickle.load(handle)
        if exclude is not None:
            # Filter the labels first; the other two mappings are then reduced
            # to the slide ids that survived.
            overall_info[1] = {k: v for k, v in overall_info[1].items() if v != exclude}
            overall_info[0] = {k: v for k, v in overall_info[0].items() if k in overall_info[1]}
            overall_info[2] = {k: v for k, v in overall_info[2].items() if k in overall_info[1]}
        self.overall_info = overall_info  # [id2slide, id2label, id2split]
        id2slide = overall_info[0]
        id2split = overall_info[2]
        self.slide2id = {v: k for k, v in id2slide.items()}
        self.classes, self.class_to_idx = self._find_classes(overall_info[1])

        train_id = [item for item in id2slide if id2split[item] == 'train']
        val_id = [item for item in id2slide if id2split[item] == 'val']
        test_id = [item for item in id2slide if id2split[item] == 'test']

        if train == 'train':
            selected_id = train_id
        elif train == 'val':
            selected_id = val_id
        elif train == 'trainplusval':
            selected_id = train_id + val_id
        else:
            # Any unrecognised value falls back to the test split.
            selected_id = test_id

        # path, label, position, parent slide
        self.plpp = get_slide_path([id2slide[item] for item in selected_id], path)
        self.current_id = selected_id

    def _find_classes(self, id2label):
        """Return the sorted distinct labels and a label -> index mapping."""
        classes = sorted(set(id2label.values()))
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index):
        """Return ``(sample, target, position, slide_id)`` for one patch.

        ``sample`` is the RGB image (after ``transforms`` when set), ``target``
        is the class index of the patch label, ``position`` is the entry of the
        position list for this patch and ``slide_id`` is the id of the slide
        the patch belongs to.
        """
        sample = pil_loader(self.plpp[0][index])
        if self.transforms is not None:
            sample = self.transforms(sample)
        # The patch label must be one of the labels found in id2label.
        target = self.class_to_idx[self.plpp[1][index]]
        return sample, target, self.plpp[2][index], self.slide2id[self.plpp[3][index]]

    def __len__(self):
        """Return the number of patches currently indexed."""
        return len(self.plpp[0])

    def pick_WSI(self, wsi_id):
        """Re-index the dataset so that it only serves patches of one slide.

        Args:
            wsi_id: id of the slide, a key of ``id2slide``.
        """
        self.plpp = get_slide_path([self.overall_info[0][wsi_id]], self.path)
        # NOTE(review): this stores a single id, whereas __init__ stores a list
        # of ids -- confirm that callers expect the scalar here.
        self.current_id = wsi_id