-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset_functions.py
More file actions
134 lines (102 loc) · 3.81 KB
/
dataset_functions.py
File metadata and controls
134 lines (102 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import albumentations as A
import randomness
from hypercube_set import *
from imblearn.over_sampling import *
from imblearn.under_sampling import *
from keras.utils import to_categorical
from randomness import *
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
def augment_chunks(patch, label):
    """
    Expand a dataset with geometric variants of every patch.

    Every patch is passed through ``augment_chunk``. Each returned variant
    is paired with that patch's label. The combined arrays are then
    reordered with ``shuffle``.

    Parameters
    ----------
    patch : sequence of numpy.ndarray
        Input patches.
    label : sequence
        Label of each patch, indexed in step with ``patch``.

    Returns
    -------
    tuple
        ``(images, labels)`` as shuffled numpy arrays.
    """
    # NOTE(review): neither augment_chunk nor shuffle reads NumPy's global
    # RNG. The seeding is kept only to preserve this function's side effect
    # on global random state.
    np.random.seed(random_seed)
    out_images = []
    out_labels = []
    for idx, sample in enumerate(patch):
        variants = augment_chunk(sample)
        out_images.extend(variants)
        # Repeat the source label once per generated variant.
        out_labels.extend([label[idx]] * len(variants))
    return shuffle(np.asarray(out_images), np.asarray(out_labels))
def augment_chunk(img):
    """
    Return a patch together with four flipped/rotated copies of it.

    The returned tuple is, in order:
    the original ``img`` (same object), a left-right flip (axis 1),
    an up-down flip (axis 0), a 90-degree rotation with ``k=1``, and a
    rotation with ``k=3``. Both rotations are in the plane of axes (0, 1).
    """
    flipped = tuple(flip(img) for flip in (np.fliplr, np.flipud))
    rotated = tuple(np.rot90(img, k=turns, axes=(0, 1)) for turns in (1, 3))
    return (img,) + flipped + rotated
def balance_classes(patch, label, smote=True, clustering=True, reduce=False, strategy='not minority'):
    """
    Balance the class distribution by under- or over-sampling.

    Parameters
    ----------
    patch : array-like
        Samples. Every ``patch[i]`` must have the same shape (e.g. an image
        patch). Samples are flattened for the sampler and reshaped back.
    label : array-like
        Class label of each sample, aligned with ``patch``.
    smote : bool
        Over-sampling only: use SMOTE when True, RandomOverSampler otherwise.
    clustering : bool
        Under-sampling only: use ClusterCentroids when True,
        RandomUnderSampler otherwise.
    reduce : bool
        Under-sample when True, over-sample when False.
    strategy : str or dict
        Forwarded to the sampler as ``sampling_strategy``.

    Returns
    -------
    tuple
        ``((x_balanced, y_balanced), x_unused, y_unused)``.
        The balanced pair is reshaped to the per-sample shape and shuffled.
        ``x_unused``/``y_unused`` are the original samples whose indices are
        not in the sampler's ``sample_indices_``. SMOTE and ClusterCentroids
        synthesise samples and expose no ``sample_indices_``. For them the
        unused samples cannot be identified, so empty arrays (with the
        original trailing shape) are returned.
    """
    if reduce:
        sampler_cls = ClusterCentroids if clustering else RandomUnderSampler
    else:
        sampler_cls = SMOTE if smote else RandomOverSampler
    sampler = sampler_cls(sampling_strategy=strategy, random_state=random_seed)

    sample_shape = np.shape(patch[0])
    # The samplers expect a 2-D (n_samples, n_features) matrix.
    flat = np.reshape(patch, (len(patch), -1))
    x_balanced, y_balanced = sampler.fit_resample(flat, label)
    x_balanced = np.reshape(x_balanced, (len(x_balanced),) + sample_shape)

    kept = getattr(sampler, 'sample_indices_', None)
    if kept is None:
        # Previously this raised AttributeError for SMOTE / ClusterCentroids.
        x_unused = np.asarray(patch)[:0]
        y_unused = np.asarray(label)[:0]
    else:
        x_unused = np.delete(patch, kept, axis=0)
        y_unused = np.delete(label, kept, axis=0)
    return shuffle(x_balanced, y_balanced), x_unused, y_unused
def embed_manifold(patch, tsne=False, num_components=2):
    """
    Project samples into a low-dimensional embedding with t-SNE or UMAP.

    Parameters
    ----------
    patch : array-like
        Samples to embed. Inputs with more than two dimensions (e.g. a batch
        of image patches) are flattened to one feature vector per sample,
        because both reducers require 2-D input.
    tsne : bool
        Use sklearn's ``TSNE`` when True, ``umap.UMAP`` otherwise (default).
    num_components : int
        Dimensionality of the embedding.

    Returns
    -------
    numpy.ndarray
        Embedding of shape ``(len(patch), num_components)``.
    """
    if np.ndim(patch) > 2:
        patch = np.reshape(patch, (len(patch), -1))
    if tsne:
        reducer = TSNE(random_state=random_seed, n_components=num_components)
    else:
        # ``umap`` is not among this file's explicit imports. Import it here
        # so the default path does not depend on a star import re-exporting
        # it, and so the heavy import is only paid when UMAP is used.
        import umap
        reducer = umap.UMAP(random_state=random_seed, n_components=num_components)
    return reducer.fit_transform(patch)
def get_center(data):
    """
    Extract the centre pixel of every patch in a batch.

    Parameters
    ----------
    data : numpy.ndarray
        Batch indexed as ``data[sample, row, col, ...]``. Any trailing axes
        (e.g. bands/channels) are kept. A batch without a trailing axis is
        also accepted.

    Returns
    -------
    numpy.ndarray
        ``data[:, rows // 2, cols // 2, ...]``. For even sizes this picks the
        lower/right of the two middle pixels.
    """
    # Read the patch size from the batch shape (not data[0]) so an empty
    # batch does not raise IndexError.
    center_row = data.shape[1] // 2
    center_col = data.shape[2] // 2
    return data[:, center_row, center_col, ...]
def hot_encode(label, num_classes=None):
    """
    One-hot encode integer class labels with Keras ``to_categorical``.

    Parameters
    ----------
    label : array-like of int
        Class indices.
    num_classes : int, optional
        Width of the encoding. When ``None`` (the default, matching the
        previous behaviour) Keras infers it as ``max(label) + 1``. Pass it
        explicitly when encoding splits separately, so that every split has
        the same width even if the highest class is missing from one of them.

    Returns
    -------
    numpy.ndarray
        One row per label, with a 1 in that label's column.
    """
    return to_categorical(label, num_classes=num_classes)
def reduce_labels_center(label):
    """
    Collapse each label patch to the value at its centre pixel.

    Parameters
    ----------
    label : numpy.ndarray
        Batch of per-pixel label maps indexed as ``label[i][row, col]``.
        The centre is taken from the shape of ``label[0]``.

    Returns
    -------
    numpy.ndarray
        1-D ``int32`` array holding ``label[i][rows // 2, cols // 2]`` for
        every ``i``.
    """
    count = label.shape[0]
    first_shape = label[0].shape
    mid_row = first_shape[0] // 2
    mid_col = first_shape[1] // 2
    centres = np.zeros((count,), np.int32)
    for idx, label_map in enumerate(label):
        centres[idx] = label_map[mid_row, mid_col]
    return centres
def remove_labels(patches, labels, r_labels=None):
    """
    Drop every sample whose label is one of ``r_labels``.

    Parameters
    ----------
    patches : numpy.ndarray
        Samples, aligned with ``labels`` along the first axis.
    labels : numpy.ndarray
        Class label per sample. Assumed 1-D (TODO confirm; one-hot labels
        are not handled).
    r_labels : iterable, optional
        Labels to remove. ``None`` (default) or empty removes nothing.

    Returns
    -------
    tuple
        ``(patches, labels)`` restricted to the kept samples, in their
        original order. When nothing is removed, the inputs are returned
        unchanged.
    """
    # ``None`` instead of a mutable ``[]`` default.
    r_labels = [] if r_labels is None else list(r_labels)
    if not r_labels:
        return patches, labels
    # A single membership mask replaces one filtering pass per removed label.
    keep = ~np.isin(labels, r_labels)
    return patches[keep], labels[keep]
def shuffle(patch, label):
    """
    Reorder samples and labels with one reproducible random permutation.

    The permutation comes from a fresh ``RandomState`` seeded with the
    module-level ``random_seed``. Equal-length inputs therefore always get
    the same ordering, and NumPy's global RNG is not consumed.
    """
    generator = np.random.RandomState(seed=random_seed)
    order = generator.permutation(len(patch))
    shuffled_patch = patch[order]
    shuffled_label = label[order]
    return shuffled_patch, shuffled_label
def split_train_test(patch, labels, test_size=0.7, random_seed=42):
    """
    Randomly split samples and labels into training and testing subsets.

    Parameters
    ----------
    patch : array-like
        Samples.
    labels : array-like
        Labels aligned with ``patch``.
    test_size : float
        Fraction of samples assigned to the test set. Note that the default
        of 0.7 puts 70% of the data in the test set.
    random_seed : int
        Seed for ``randomness.set_seed`` and for the split itself. It
        shadows the module-level ``random_seed`` inside this function.

    Returns
    -------
    list
        ``[x_train, x_test, y_train, y_test]`` from sklearn's
        ``train_test_split``.
    """
    # NOTE(review): randomness.set_seed is defined elsewhere; presumably it
    # seeds the global RNGs. Confirm in randomness.py.
    randomness.set_seed(random_seed)
    split_options = dict(test_size=test_size, shuffle=True, random_state=random_seed)
    return train_test_split(patch, labels, **split_options)