-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
177 lines (140 loc) · 6.1 KB
/
train.py
File metadata and controls
177 lines (140 loc) · 6.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#Contains the actual cnn model and training process
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import classification_report
from skorch.callbacks import EarlyStopping, Checkpoint
from skorch import NeuralNetClassifier
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision.datasets
import numpy as np
import torch.optim as optim
from cnnModel import CNN
from cnnModel import CNNV1
from cnnModel import CNNV2
#Data directory for the image dataset on the local machine (replace with your own path)
dataPath = "Your Path to Dataset" #main
#Path where the trained model weights will be saved (replace with your own path)
#modelPath = "Your Path To Save Model" #main
#modelPath = "Your Path To Save Model 2" #V1
#modelPath = "Your To Save Model 3" #V2
#Trying to mitigate bias
modelPath = "Your PAth to Bias Model" #main no bias?"
"""
Below is the training process for the CNN model.
Along with the showing of the loss for each epoch(iteration) for the model being trained.
Saving the model after the training been done and
Showing the accuracy of the model on the test dataset.
"""
def train(model, trainLoader, valLoader, criterion, optimizer, device, num_epochs, modelPath = modelPath, patience = 5):
    """Train the model with per-epoch validation, best-checkpoint saving, and early stopping.

    For each epoch: run one training pass over trainLoader, then evaluate on
    valLoader. The checkpoint with the lowest validation loss is saved to
    modelPath + "_best.pth"; training stops early once validation loss has not
    improved for more than `patience` consecutive epochs. The final weights
    (whatever state the model is in when the loop ends) are saved to modelPath.

    Args:
        model: the torch.nn.Module to train (assumed already on `device` -- TODO confirm caller moves it).
        trainLoader: DataLoader yielding (images, labels) training batches.
        valLoader: DataLoader yielding (images, labels) validation batches.
        criterion: loss function, e.g. nn.CrossEntropyLoss.
        optimizer: torch optimizer bound to model.parameters().
        device: torch.device the batches are moved to.
        num_epochs: maximum number of epochs to run.
        modelPath: base path for saving the final and best-fit weights.
        patience: epochs without validation-loss improvement before early stop.
    """
    #Early stopping state
    bestValLoss = float('inf')
    # FIX: removed local `patience = 5` that shadowed (and discarded) the parameter
    noImprovementCount = 0
    #best-fit save path
    bestModel = modelPath + "_best.pth"
    #training loop
    for epoch in range(num_epochs):
        #---- training phase ----
        model.train()
        trainCorrect = 0
        trainTotal = 0
        trainLoss = 0.0
        for images, labels in trainLoader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            #accumulate loss weighted by batch size so the per-sample mean is exact
            trainLoss += loss.item() * images.size(0)
            #train accuracy
            _, trainPredicted = torch.max(outputs.data, 1)
            trainTotal += labels.size(0)
            trainCorrect += (trainPredicted == labels).sum().item()
        #training accuracy and loss for the epoch
        trainAccuracy = trainCorrect/trainTotal
        trainLoss /= len(trainLoader.dataset)
        #---- validation phase ----
        model.eval()
        valLoss = 0.0
        valCorrect = 0
        valTotal = 0
        with torch.no_grad():  #single no-grad context for the whole pass
            for images, labels in valLoader:
                # FIX: validation batches must also be moved to `device`;
                # previously they stayed on CPU and crashed when the model was on GPU
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                #val loss
                valLoss += loss.item() * images.size(0)
                #val accuracy
                _, predicted = torch.max(outputs.data, 1)
                valTotal += labels.size(0)
                valCorrect += (predicted == labels).sum().item()
        #validation results for the epoch
        valLoss /= len(valLoader.dataset)
        valAccuracy = valCorrect/valTotal
        #print results for each epoch
        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {trainLoss:.4f}, Validation Loss: {valLoss:.4f}, "
              f"Training Accuracy: {trainAccuracy * 100:.2f}%, Validation Accuracy: {valAccuracy*100:.2f}%")
        #best fit checking: keep the checkpoint with the lowest validation loss
        if valLoss < bestValLoss:
            bestValLoss = valLoss
            #save best-fit model
            torch.save(model.state_dict(), bestModel)
            noImprovementCount = 0
        else:
            noImprovementCount += 1
        #Early stopping
        if noImprovementCount > patience:
            print("*** Early Stopping Happened! ****")
            break
    #saving the final model (last state, not necessarily the best one)
    torch.save(model.state_dict(), modelPath)
    print("\n ++++++ Training Complete!! +++++ ")
if __name__ == "__main__":
    #Set random seeds so runs are reproducible
    torch.manual_seed(0)
    np.random.seed(0)
    #---- hyperparameters ----
    num_epochs = 10 #minimum of 10 iterations
    num_classes = 4 #total number of classes
    learningRate = 0.001 #learning rate for the optimizer
    #---- preprocessing pipeline ----
    transform = transforms.Compose(
        [transforms.Grayscale(num_output_channels=1), #images are grayscale
         transforms.Resize((48,48)), #image size 48 x 48
         transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))])
    #getting dataset for model to train on
    dataset = torchvision.datasets.ImageFolder(dataPath, transform=transform)
    #randomly split the dataset: 70% train, 10% test, remainder (~20%) validation
    #NOTE(review): comment on the original said "20% testing, 10% validation" but the
    #code computes testSize = 10% and gives the remainder to validation -- confirm intent
    m = len(dataset)
    trainSize = int(0.7*m)
    testSize = int(0.1*m)
    valSize = m - trainSize - testSize #remainder guarantees the three sizes sum to m
    trainData, testData, valData = random_split(dataset, [trainSize,testSize,valSize])
    #---- data loaders ----
    #shuffle only the training set; 2 worker subprocesses per loader
    trainLoader = DataLoader(trainData, batch_size=32, shuffle=True, num_workers=2)
    testLoader = DataLoader(testData, batch_size=32, shuffle=False, num_workers=2)
    valLoader = DataLoader(valData, batch_size=32, shuffle=False, num_workers=2)
    #use the GPU when CUDA is available, otherwise fall back to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #yTrain = np.array ([y for x, y,in iter (trainData)])
    #NOTE(review): "surpise" looks like a typo for "surprise" -- confirm against the
    #dataset's folder names before changing, since ImageFolder derives labels from them
    classes = ("angry", "neutral", "engaged", "surpise") #classes for classification
    #creating an instance of the cnn model
    #model = CNNV1() #train with variant 1
    #model = CNNV2() #train with variant 2
    model = CNN() #train with main
    # FIX: the model must live on the same device as the batches; without this,
    # training crashes with a device-mismatch error whenever CUDA is available
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    #optimizer is created AFTER the model is moved to the device, as recommended
    optimizer = torch.optim.Adam(model.parameters(), lr = learningRate)
    train(model, trainLoader, valLoader, criterion, optimizer, device, num_epochs)