-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
117 lines (86 loc) · 3.38 KB
/
train.py
File metadata and controls
117 lines (86 loc) · 3.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# -*- coding: utf-8 -*-
"""
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/12Lx71CDeZWSJ2IlJA4UiVcBjGGx0A5pQ
"""
import torch
# Function to train the network
def train_model(model, train_loader, val_loader, optimizer, cost_function, patience=10, epochs=50):
    """
    Train `model` with early stopping on validation loss.

    Args:
        model: nn.Module whose forward returns a tuple ``(logits, extra)``;
            only the logits are used here.
        train_loader: iterable yielding ``(inputs, targets)`` training batches.
        val_loader: iterable yielding ``(inputs, targets)`` validation batches.
        optimizer: torch optimizer configured over ``model``'s parameters.
        cost_function: loss taking ``(logits, class-index targets)``,
            e.g. ``nn.CrossEntropyLoss``.
        patience: number of consecutive epochs without validation-loss
            improvement before training stops early.
        epochs: maximum number of training epochs.

    Returns:
        (model, history): the model with the best-validation-loss weights
        restored, and a dict of per-epoch lists under the keys
        'train_loss', 'train_accuracy', 'val_loss', 'val_accuracy'.
    """
    # Move the model to GPU if one is available, otherwise train on CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Early-stopping state: best validation loss seen so far, how many
    # epochs have passed without improvement, and a CPU copy of the best weights.
    best_val_loss = float('inf')
    epochs_without_improve = 0
    best_state_dict = None

    # Per-epoch history of losses and accuracies.
    history = {'train_loss': [], 'train_accuracy': [], 'val_loss': [], 'val_accuracy': []}

    for epoch in range(1, epochs + 1):
        # ---- training for one epoch ----
        model.train()
        train_samples = 0.
        cumulative_loss = 0.
        cumulative_corrects = 0
        for inputs, targets in train_loader:
            # Load data onto the training device.
            inputs = inputs.to(device)
            # reshape(-1) flattens (B, 1) targets to (B,) like squeeze(),
            # but unlike squeeze() it still yields a 1-D tensor when B == 1
            # (squeeze() would produce a 0-dim tensor and crash CrossEntropyLoss).
            targets = targets.to(device).reshape(-1)
            # Forward pass (model returns (logits, extra); extra is unused).
            outputs, _ = model(inputs)
            # Compute the loss.
            loss = cost_function(outputs, targets)
            # Reset gradients, backpropagate, update parameters.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate loss weighted by batch size so the epoch average
            # is correct even with a ragged final batch.
            cumulative_loss += loss.item() * inputs.size(0)
            train_samples += inputs.size(0)
            preds = torch.argmax(outputs, dim=1)
            cumulative_corrects += (preds == targets).sum().item()
        epoch_train_accuracy = cumulative_corrects / train_samples
        epoch_train_loss = cumulative_loss / train_samples
        history['train_loss'].append(epoch_train_loss)
        history['train_accuracy'].append(epoch_train_accuracy)

        # ---- validation for one epoch ----
        model.eval()
        val_loss = 0.
        val_samples = 0.
        val_corrects = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                # Same flattening as in the training loop (see comment above).
                targets = targets.to(device).reshape(-1)
                outputs, _ = model(inputs)
                loss = cost_function(outputs, targets)
                val_loss += loss.item() * inputs.size(0)
                val_samples += inputs.size(0)
                preds = torch.argmax(outputs, dim=1)
                val_corrects += (preds == targets).sum().item()
        epoch_val_accuracy = val_corrects / val_samples
        epoch_val_loss = val_loss / val_samples
        history['val_loss'].append(epoch_val_loss)
        history['val_accuracy'].append(epoch_val_accuracy)

        print(f"Epoch {epoch:02d} | "
              f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {100 * epoch_train_accuracy:.2f}% | "
              f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {100 * epoch_val_accuracy:.2f}%")

        # ---- early-stopping bookkeeping ----
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            epochs_without_improve = 0
            # Save the best model’s weights on CPU so they survive even if
            # GPU memory is later freed; clone() detaches from live storage.
            best_state_dict = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        else:
            epochs_without_improve += 1
            print(f"No improvement in validation loss for {epochs_without_improve}/{patience} epochs.")
            if epochs_without_improve >= patience:
                print(f"Early stopping triggered. Restoring best model from epoch {epoch-epochs_without_improve}.")
                break

    # After the loop finishes (or breaks early), load the best weights back.
    # load_state_dict copies into existing parameters, so device is preserved.
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)
    return model, history