-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
86 lines (67 loc) · 2.89 KB
/
dataset.py
File metadata and controls
86 lines (67 loc) · 2.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Dataset
import numpy as np
import pickle
import os
from PIL import Image
from utils import cutmix_data
import utils
from matplotlib import pyplot as plt
def oversample_class(data, labels, target_class, factor=2):
    """Oversample a specific class in the dataset by a given factor.

    Every sample whose label equals ``target_class`` is appended to the end
    of the dataset ``factor - 1`` additional times, so that class ends up
    appearing ``factor`` times as often as before.

    Args:
        data: Array-like of samples, indexed along axis 0.
        labels: Array-like of labels aligned with ``data`` along axis 0.
            Plain Python lists are accepted and converted to arrays.
        target_class: Label value of the class to duplicate.
        factor: Total multiplicity wanted for the target class. Values of
            1 or less leave the dataset unchanged.

    Returns:
        A ``(data, labels)`` tuple of NumPy arrays with the extra copies
        appended after the original samples.
    """
    # Coerce to arrays first: comparing a plain list with ``==`` yields a
    # single bool instead of an element-wise mask, which would silently
    # select nothing.
    data = np.asarray(data)
    labels = np.asarray(labels)

    class_indices = np.where(labels == target_class)[0]
    repeats = factor - 1
    if repeats <= 0 or class_indices.size == 0:
        return data, labels

    extra_data = data[class_indices]
    extra_labels = labels[class_indices]

    # One concatenate call instead of one per copy avoids recopying the
    # growing arrays on every iteration.
    data = np.concatenate([data] + [extra_data] * repeats, axis=0)
    labels = np.concatenate([labels] + [extra_labels] * repeats, axis=0)
    return data, labels
def load_data(batch_size=64, shuffle_train=True):
    """Build the CIFAR-10 training and test data loaders.

    The dataset is stored under ``./data`` and downloaded on first use.
    Training samples go through random crop, horizontal flip and colour
    jitter before normalization; test samples are only converted to tensors
    and normalized.

    Args:
        batch_size: Number of samples per batch for both loaders.
        shuffle_train: Whether the training loader reshuffles its samples
            every epoch. Pass ``False`` to iterate in the fixed dataset
            order. The test loader is never shuffled.

    Returns:
        A ``(trainloader, testloader)`` tuple of ``DataLoader`` objects.
    """
    # Per-channel mean/std handed to Normalize; the same values are used for
    # the training and the test pipeline.
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2470, 0.2435, 0.2616]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # download CIFAR10 dataset
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=train_transform)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=test_transform)

    # dataloader: without shuffling, every epoch would present the same
    # batches in the same order.
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=shuffle_train)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    return trainloader, testloader
def show_image(img_tensor, label):
    """Display one normalized image tensor with its class name as the title.

    Args:
        img_tensor: Image tensor with the channel axis first (it is moved to
            the last axis for plotting), normalized with the mean/std
            defined below.
        label: Integer index into the ten-entry class-name list.
    """
    # labels
    classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    # detach()/cpu() make the conversion work for tensors that track
    # gradients or live on an accelerator; for plain CPU tensors they are
    # no-ops.
    img = img_tensor.detach().cpu().permute(1, 2, 0).numpy()

    # Undo the normalization and rescale to the 0..255 range.
    mean = np.array([0.4914, 0.4822, 0.4465])
    std = np.array([0.2470, 0.2435, 0.2616])
    img = (img * std + mean) * 255.0

    # Round and clip before casting: a bare uint8 cast truncates (so 254.99998
    # would become 254) and lets out-of-range values wrap around.
    img = np.clip(np.rint(img), 0, 255).astype(np.uint8)

    # show img
    plt.imshow(img)
    plt.title(f"Label: {classes[label]}")
    plt.axis('off')  # Hide axes for better image display
    plt.show()
if __name__ == "__main__":
    # Manual smoke test: build the loaders, print their sizes and the value
    # range of one sample, then display a few training images.
    trainloader, testloader = load_data()
    print(f"Number of training batches: {len(trainloader)}")
    print(f"Number of testing batches: {len(testloader)}")
    print(f"Number of training samples: {len(trainloader.dataset)}")
    print(f"Number of testing samples: {len(testloader.dataset)}")

    # Fetch sample 0 once: every dataset access re-applies the random
    # training transforms, so indexing twice would report the max and the
    # min of two different augmented images.
    first_image = trainloader.dataset[0][0].numpy()
    print(np.max(first_image))
    print(np.min(first_image))

    # show 10 images from the start of the training set, then 10 more
    # starting at index 128
    for i in [*range(10), *range(128, 138)]:
        image, label = trainloader.dataset[i]
        show_image(image, label)