main.py
import copy
import torch
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import numpy as np
from tqdm import tqdm
# Use Apple's MPS backend if available, otherwise fall back to the CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)
# Define transformations for the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# Load the dataset
dataset = datasets.ImageFolder(root='../../Downloads/orc_vs_elf', transform=transform)
print(dataset.classes)
# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
# Example of iterating through the DataLoader
# for images, labels in dataloader:
# print(images.shape, labels)
# Load a ResNet-18 pretrained on ImageNet and replace the final layer for our classes
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(512, len(dataset.classes))
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
# Define the split sizes
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
# Split the dataset
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
# Create DataLoader instances
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)
def train_model(mdl, train_loader, validation_loader, criterion, optimizer, n_epochs, print_=True):
    loss_list = []
    accuracy_list = []
    n_test = len(validation_loader.dataset)
    accuracy_best = 0
    best_model_wts = copy.deepcopy(mdl.state_dict())
    print("The first epoch should take several minutes")
    # Loop through epochs
    for epoch in tqdm(range(n_epochs)):
        loss_sublist = []
        # Training phase: loop through the data in the training loader
        mdl.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            z = mdl(x)
            loss = criterion(z, y)
            loss_sublist.append(loss.item())
            loss.backward()
            optimizer.step()
        print("epoch {} done".format(epoch))
        # Step the learning-rate scheduler (defined at module level)
        scheduler.step()
        loss_list.append(np.mean(loss_sublist))
        # Validation phase: measure accuracy on the held-out set
        correct = 0
        mdl.eval()
        with torch.no_grad():
            for x_test, y_test in validation_loader:
                x_test, y_test = x_test.to(device), y_test.to(device)
                z = mdl(x_test)
                _, yhat = torch.max(z, 1)
                correct += (yhat == y_test).sum().item()
        accuracy = correct / n_test
        accuracy_list.append(accuracy)
        # Keep a copy of the weights that give the best validation accuracy
        if accuracy > accuracy_best:
            accuracy_best = accuracy
            best_model_wts = copy.deepcopy(mdl.state_dict())
        if print_:
            print('learning rate', optimizer.param_groups[0]['lr'])
            print("The training loss for epoch " + str(epoch + 1) + ": " + str(np.mean(loss_sublist)))
            print("The validation accuracy for epoch " + str(epoch + 1) + ": " + str(accuracy))
    # Restore the best weights before returning
    mdl.load_state_dict(best_model_wts)
    return accuracy_list, loss_list, mdl
# Train for 10 epochs, then save the best-performing weights
accuracy_list, loss_list, model = train_model(model, train_loader, test_loader, criterion, optimizer, 10)
torch.save(model.state_dict(), 'model-elf.pth')
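# A minimal sketch (not part of the original script) of how the saved weights
# could be loaded back for inference; 'path/to/image.jpg' is a hypothetical
# placeholder path.
# from PIL import Image
# reload_model = models.resnet18(weights=None)
# reload_model.fc = nn.Linear(512, len(dataset.classes))
# reload_model.load_state_dict(torch.load('model-elf.pth', map_location=device))
# reload_model.to(device)
# reload_model.eval()
# img = transform(Image.open('path/to/image.jpg').convert('RGB')).unsqueeze(0).to(device)
# with torch.no_grad():
#     pred = reload_model(img).argmax(dim=1).item()
# print(dataset.classes[pred])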