-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
100 lines (77 loc) · 3.52 KB
/
train.py
File metadata and controls
100 lines (77 loc) · 3.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import time
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import MultiStepLR
from tqdm import tqdm
from logger import Logger
import warnings
warnings.filterwarnings('ignore')
def train(model, train_loader, val_loader, num_epochs, optimizer, device):
    """Train `model` on `train_loader` for `num_epochs` epochs, running a
    validation pass over `val_loader` after every epoch.

    The model is moved to `device` once up front; each batch dict is moved
    per iteration. Metrics are reported through `Logger`, and the model is
    checkpointed via `logger.save_models` whenever a validation batch sets a
    new (per-epoch) best loss.

    Args:
        model: HuggingFace-style model — `model(**batch)` must return an
            object exposing `.loss` and `.logits`.
        train_loader: iterable of dict batches (tensors), incl. a "labels" key.
        val_loader: iterable of dict batches used for validation.
        num_epochs: number of full passes over `train_loader`.
        optimizer: torch optimizer already bound to `model.parameters()`.
        device: target torch device for the model and batches.
    """
    num_training_steps = num_epochs * len(train_loader)
    num_batches = len(train_loader)
    progress_bar = tqdm(range(num_training_steps))
    start_time = time.time()
    logger = Logger(model_name = 'distilbert-base-uncased', data_name = 'Reddit_Train_')

    # Hoisted: .to(device) is idempotent, so moving the model once here is
    # equivalent to (and much cheaper than) re-calling it on every batch.
    model = model.to(device)

    for epoch in range(num_epochs):
        # BUGFIX: must re-enter training mode every epoch. Validation below
        # calls model.eval(); the original called model.train() only once
        # before the loop, so epochs 2+ trained with dropout disabled.
        model.train()
        epoch_start = time.time()
        train_acc = 0.0
        train_losses = 0.0
        for i, batch in enumerate(train_loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(**batch)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim = 1)
            targets = batch["labels"]
            # NOTE(review): the transpose implies labels are shaped
            # (batch, 1); the comparison then broadcasts against the 1-D
            # predictions. Confirm label shape against the dataloader.
            batch_acc = ((predictions == torch.transpose(targets, 0, 1)).sum() / len(targets)).item()
            train_acc += batch_acc
            loss = outputs.loss
            loss.backward()
            train_losses += loss.item()
            optimizer.step()
            progress_bar.update(1)
            logger.log(loss, batch_acc, epoch, i, num_batches)
            if i % 100 == 0:
                logger.display_status(epoch, num_epochs, i, num_batches, loss, batch_acc)
        epoch_end = time.time()
        # (Fixed: the original printed the "Epoch Training Time" field twice.)
        print(
            f' ########### End of Epoch {epoch + 1} ############### \
            | Epochs: {epoch + 1} | Training Batch Loss: {train_losses / len(train_loader): .4f} \
            | Train Accuracy: {train_acc / len(train_loader): .4f} \
            | Epoch Training Time: {epoch_end - epoch_start} s')

        # ---------------- validation pass for this epoch ----------------
        val_accs = 0.0
        val_losses = 0.0
        # Plain float (was a torch.tensor): it is only ever compared to and
        # assigned from loss.item(), which is a Python float.
        best_val_loss = float("inf")
        best_val_acc = 0.0
        print("Validation progress ...")
        val_start = time.time()
        with torch.no_grad():
            model.eval()
            for i, batch in enumerate(val_loader):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                logits = outputs.logits
                predictions = torch.argmax(logits, dim = 1)
                targets = batch['labels']
                val_acc = ((predictions == torch.transpose(targets, 0, 1)).sum() / len(targets)).item()
                val_accs += val_acc
                loss = outputs.loss
                val_losses += loss.item()
                # Checkpoint whenever a validation batch sets a new best loss.
                # NOTE(review): "best" is per-batch and resets every epoch —
                # confirm this matches the intended checkpoint policy.
                if loss.item() < best_val_loss:
                    best_val_loss = loss.item()
                    best_val_acc = val_acc
                    logger.save_models(model, epoch, i)
        val_end = time.time()
        print(f'Total Average Validation Loss: {val_losses/len(val_loader)}')
        print(f'Total Average Validation Accuracy: {val_accs/len(val_loader)}')
        print(f'Best Validation Loss: {best_val_loss}')
        print(f'Best Validation Accuracy: {best_val_acc}')
        print(f'Validation Time: {val_end - val_start} s')
    print(f"Total Training Time: {time.time() - start_time} s")