-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
85 lines (61 loc) · 2.51 KB
/
train.py
File metadata and controls
85 lines (61 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import sys
import os
import copy
import logging
from time import time
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.utils.tensorboard import SummaryWriter
from matplotlib import pyplot as plt
from data.data import create_dataset
from seed import seed_everything
from options.train_options import TrainOptions
from DropLR import DropLR
from Trainer import Trainer
TORCH_SEED = 0
if __name__=="__main__":
seed_everything(TORCH_SEED)
train_opt = TrainOptions().parse()
val_opt = TrainOptions().parse()
val_opt.isTrain = False
train_logger = logging.getLogger("Train_Logger")
train_logger.setLevel(logging.INFO)
logs_dir = f'{train_opt.logs_dir}'
if not os.path.exists(logs_dir):
os.makedirs(logs_dir)
train_fh = logging.FileHandler(filename=f'{train_opt.logs_dir}/{train_opt.log_file}.log', encoding='utf-8')
train_fh.setLevel(logging.INFO)
train_logger.addHandler(train_fh)
opts = {'train': train_opt, 'val': val_opt}
writer = SummaryWriter(f'runs/{train_opt.name}/train')
data_directory = "dataset"
datasets = ["train", "val"]
dataloaders = {}
dataset_sizes = {}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
classes = train_opt.classes
train_logger.info(f'Classes: {classes}')
for dataset in datasets:
dst_start = time()
dataloaders[dataset], dataset_sizes[dataset] = create_dataset(f"{data_directory}/{dataset}", opts[dataset], opts[dataset].classes)
train_logger.info(f'[==== DATASET CREATION DURATION FOR {dataset} = {time() - dst_start} ====]')
# Import the pretrained ResNet50 Model trained on the ImageNet Dataset
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 1)
model = model.to(device)
m = nn.Sigmoid()
criterion = nn.BCEWithLogitsLoss()
if train_opt.optim == "adam":
optimizer = optim.Adam(model.parameters(), lr=train_opt.lr)
elif train_opt.optim == "sgd":
optimizer = optim.SGD(model.parameters(), lr=train_opt.lr)
else:
train_logger.info('Optimizer not valid!')
sys.exit()
step_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
trainer = Trainer(train_opt, dataloaders, dataset_sizes, device, train_logger)
model = trainer.train_model(model, m, criterion, optimizer, step_lr_scheduler, writer)