-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
143 lines (120 loc) · 5.71 KB
/
train.py
File metadata and controls
143 lines (120 loc) · 5.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import argparse
import os
import shutil
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
from torch import optim
from torch.utils.data import DataLoader
from model import Modified3DUNet
import glob
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
import utils
from losses import GeneralizedDiceLoss, dice_loss, FocalLoss
from torch.autograd import Variable
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from utils import get_logger
from sklearn.metrics import confusion_matrix
from utils import get_metrics
from sklearn.metrics import accuracy_score
from utils import normalize
# --- Experiment bookkeeping: logger, TensorBoard writer, checkpoint dir ---
experiment_name = '300_epochs_ce_weights'
logger = get_logger(experiment_name)
writer = SummaryWriter(os.path.join('../data/verse/models/logs/', experiment_name))
# BUGFIX: the original `if not exists: mkdir` was a check-then-create race
# and os.mkdir fails when a parent directory is missing; makedirs with
# exist_ok=True handles both cases atomically enough for this script.
os.makedirs(os.path.join('../data/verse/models/', experiment_name), exist_ok=True)
'''
This will fetch the data and give it to the network -- helps in step 2 of the repo design
'''
# Collect patch file paths. BUGFIX: glob.glob returns paths in arbitrary
# filesystem order, but images and masks are paired purely by list index
# below — without sorting, an image could silently be paired with the
# wrong mask. sorted() makes the pairing deterministic and aligned
# (assumes image/mask files share a common naming scheme — TODO confirm).
folder_data = sorted(glob.glob('../data/verse/patches/images/*.npy'))
folder_mask = sorted(glob.glob('../data/verse/patches/masks/*.npy'))

# Split into train/test by a fixed fraction. train_size == 1 means all
# data goes to training and the test lists are empty.
len_data = len(folder_data)
train_size = 1
split_idx = int(len_data * train_size)
train_image_paths = folder_data[:split_idx]
test_image_paths = folder_data[split_idx:]
train_mask_paths = folder_mask[:split_idx]
test_mask_paths = folder_mask[split_idx:]
class CustomDataset(Dataset):
    """Lazily loads index-aligned (image, mask) patch pairs from .npy files.

    Parameters
    ----------
    image_paths : list[str]
        Paths to image patch ``.npy`` files.
    target_paths : list[str]
        Paths to the corresponding mask ``.npy`` files; must be
        index-aligned with ``image_paths``.
    train : bool
        Unused; kept for backward compatibility with existing callers.
    """

    def __init__(self, image_paths, target_paths, train=True):
        self.image_paths = image_paths
        self.target_paths = target_paths
        # NOTE(review): the original also built transforms.ToTensor() here
        # but never applied it anywhere; the dead attribute was removed.

    def __getitem__(self, index):
        """Return ``(image_tensor, mask_tensor)`` for the given index.

        Arrays are read from disk on every access (no caching).
        """
        image = torch.from_numpy(np.load(self.image_paths[index]))
        # Masks are cast to uint8 before conversion — assumes label values
        # fit in [0, 255] (safe here since n_classes == 2).
        mask = torch.from_numpy(np.array(np.load(self.target_paths[index]), dtype=np.uint8))
        return image, mask

    def __len__(self):
        # Number of samples == number of image paths supplied.
        return len(self.image_paths)
# DataLoaders: batch_size=1 (volumetric patches are memory-heavy) with 12
# loader worker processes; only the training split is shuffled.
train_dataset = CustomDataset(train_image_paths, train_mask_paths, train=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=12)
test_dataset = CustomDataset(test_image_paths, test_mask_paths, train=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=12)
# Model configuration: single-channel 3D volumes, binary (background/vertebra)
# segmentation. Modified3DUNet is project-local; requires a CUDA device.
in_channels = 1
n_classes = 2
base_n_filter = 16
model = Modified3DUNet(in_channels, n_classes, base_n_filter).cuda()
# weights = [0.01, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0,10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0,10.0,10.0,10.0,10.0,10.0]
# Class weights counteract the heavy foreground/background imbalance:
# background weighted 0.1, foreground 50.0 — presumably tuned empirically;
# TODO confirm against the dataset's class frequencies.
weights = [0.1, 50.0]
class_weights = torch.FloatTensor(weights).cuda()
# loss_function = GeneralizedDiceLoss(weight = class_weights)
loss_function = nn.CrossEntropyLoss(weight = class_weights)
# loss_function = FocalLoss()
# Adam with library defaults (lr=1e-3).
optimizer = optim.Adam(model.parameters())
# criterion = FocalLoss(num_class = 2)
epochs = 300
# model = nn.DataParallel(model, NumerofGPU)
# --- Training loop: per-batch loss/metrics, per-epoch checkpointing ---
for epoch in range(epochs):
    model.train()
    logger.info('Starting @ epoch {}'.format(epoch))
    start = time.time()
    losses = []
    mean_accuracy = []
    for index, (image, mask) in enumerate(train_loader):
        optimizer.zero_grad()
        image = normalize(image)
        # Add a channel dim in front of the loader's batch dim; assumes the
        # loader yields (B, D, H, W) so the model sees (1, B=1, D, H, W) —
        # TODO confirm against Modified3DUNet's expected input layout.
        image = torch.unsqueeze(image, 0).float().cuda()
        # Flatten voxel labels to 1-D to match the per-voxel logits that
        # the model returns as output_1.
        label = mask.cuda().long().view(-1)
        output_1, output_2 = model(image)
        loss = loss_function(output_1, label)

        # --- Metrics (computed on detached copies; no grad involved) ---
        # argmax is invariant under softmax, so logits are used directly;
        # the original applied nn.Softmax first to no effect.
        preds = torch.argmax(output_2, 1).view(-1).cpu().detach().numpy()
        truth = mask.view(-1).cpu().detach().numpy()
        # BUGFIX: sklearn's signature is confusion_matrix(y_true, y_pred);
        # the original passed (pred, true), which transposes the matrix
        # and swaps FP/FN in the derived metrics. Now consistent with the
        # accuracy_score call below, which already used (true, pred).
        conf_matrix = confusion_matrix(truth, preds)
        TPR, TNR, PPV, FPR, FNR, ACC = get_metrics(conf_matrix)
        accuracy = accuracy_score(truth, preds)
        mean_accuracy.append(accuracy)
        # (Removed two duplicate single-metric log lines; everything is in
        # the consolidated message below.)
        logger.info('TPR == {} | \nTNR == {} | \nPRCSN == {} | \nFPR == {}\n | \nFNR == {} | \nACC == {}.'.format(TPR, TNR, PPV, FPR, FNR, ACC))
        logger.info('Epoch = {} , Accuracy = {}'.format(epoch, accuracy))

        losses.append(loss.item())
        loss.backward()
        optimizer.step()

        # BUGFIX: the original logged with `epoch` as the step, so every
        # batch within an epoch overwrote the same x-coordinate, and it
        # passed the live loss tensor instead of a float. Use a global
        # step and plain Python numbers.
        global_step = epoch * len(train_loader) + index
        writer.add_scalar('Train/Loss', loss.item(), global_step)
        writer.add_scalar('Train/Accuracy', accuracy, global_step)

    logger.info('Mean Loss = {}'.format(sum(losses) / float(len(losses))))
    logger.info('Mean Accuracy = {}'.format(sum(mean_accuracy) / float(len(mean_accuracy))))
    end = time.time() - start
    logger.info('Time taken to finish epoch {} is {}'.format(epoch, end))
    average_accuracy = sum(mean_accuracy) / float(len(mean_accuracy))
    average_loss = sum(losses) / float(len(losses))
    # Checkpoint format borrowed from
    # https://medium.com/udacity-pytorch-challengers/saving-loading-your-model-in-pytorch-741b80daf3c
    checkpoint = {'epoch': epoch,
                  'state_dict': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'accuracy': average_accuracy,
                  'loss': average_loss}
    torch.save(checkpoint, '../data/verse/models/{}/epoch_{}_checkpoint.pth'.format(experiment_name, epoch))
writer.close()