import argparse
import logging

import pandas as pd
import torch
from torch import nn

from GCDataLoader import get_loaders
from ParseConstants import get_model, get_optimizer
# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(filename)s:%(lineno)d [%(levelname)s] %(message)s')
# Configure command-line argument parsing
parser = argparse.ArgumentParser(
    description='Train a garbage classification model.')
parser.add_argument('--model', type=str, default='resnet18',
                    help='Model architecture: resnet18 (default), resnet50, or resnet101')
parser.add_argument('--bs', type=int, default=16,
                    help='Batch size (default 16)')
parser.add_argument('-aug', action='store_true',
                    help='Use image augmentation')
parser.add_argument('--lr', type=float, default=1e-4,
                    help='Learning rate (default 1e-4)')
parser.add_argument('--optim', type=str, default='SGD',
                    help='Optimizer: SGD (default), Adam, or Adadelta')
parser.add_argument('--scheduler', type=str, default='None',
                    help='LR scheduler: StepLR, ExpLr, CosLR, or Plateau (default: None)')
parser.add_argument('--epoch', type=int, default=20,
                    help='Number of epochs (default 20)')
parser.add_argument('-cbam', action='store_true',
                    help='Use CBAM attention modules')
parser.add_argument('--o', type=str, default='',
                    help='Output CSV file for per-epoch metrics (optional)')
args = parser.parse_args()
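# Example invocation (a sketch using the flags defined above; results.csv is a
# hypothetical output path):
#   python main.py --model resnet50 --bs 32 --lr 1e-3 --optim Adam \
#       --scheduler StepLR --epoch 30 -aug --o results.csv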
BATCH_SIZE = args.bs
logging.info('begin to load data...')
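# get_loaders is defined in GCDataLoader (not shown here); judging by this call
# it presumably splits the dataset into train/test by the given ratio and
# returns a DataLoader pair:
#   get_loaders(train_ratio, batch_size, augment) -> (train_loader, test_loader)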
train_loader, test_loader = get_loaders(0.8, BATCH_SIZE, args.aug)
train_len = len(train_loader.dataset)
test_len = len(test_loader.dataset)
logging.info(f'size of training set: {train_len}')
logging.info(f'size of testing set: {test_len}')
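# get_model comes from ParseConstants (not shown); presumably it builds the
# requested ResNet variant, optionally with CBAM inserted, and exposes the
# final classifier layer as model.fc (relied on below).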
# Fall back to CPU when no GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(args.model, args.cbam)
# Replace the final fully connected layer to output the 40 garbage classes
num_in = model.fc.in_features
model.fc = nn.Linear(num_in, 40)
model = model.to(device)
use_scheduler = args.scheduler != 'None'
save_results = args.o != ''
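# get_optimizer also comes from ParseConstants (not shown); based on its use
# here it presumably returns a bare optimizer, or an (optimizer, scheduler)
# pair when a scheduler name other than 'None' is passed.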
if use_scheduler:
    optimizer, scheduler = get_optimizer(
        model, args.lr, args.scheduler, args.optim)
else:
    optimizer = get_optimizer(model, args.lr, args.scheduler, args.optim)
criterion = nn.CrossEntropyLoss()
if save_results:
    logging.info('save results to file: ' + args.o)
    df = pd.DataFrame(columns=('epoch', 'train_loss',
                               'test_loss', 'test_accuracy'))
logging.info('begin to train...')
for epoch in range(args.epoch):
    train_loss = 0
    test_loss = 0
    model.train()
    for x, label in train_loader:
        x, label = x.to(device), label.to(device)
        logits = model(x)  # raw class scores; CrossEntropyLoss applies softmax internally
        loss = criterion(logits, label)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Step the scheduler once per epoch. Note: a 'Plateau' scheduler
    # (presumably ReduceLROnPlateau) would instead need a metric, e.g.
    # scheduler.step(test_loss) after evaluation.
    if use_scheduler:
        scheduler.step()
    model.eval()
    with torch.no_grad():
        total_correct = 0
        total_num = 0
        for x, label in test_loader:
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = criterion(logits, label)
            test_loss += loss.item()
            pred = logits.argmax(dim=1)
            correct = torch.eq(pred, label).float().sum().item()
            total_correct += correct
            total_num += x.size(0)
    acc = total_correct / total_num
    # Reported losses are sums of per-batch mean losses divided by dataset size
    logging.info(
        f'[{epoch}] train_loss: {train_loss / train_len}, test_loss: {test_loss / test_len}, test_accuracy: {acc}')
    if save_results:
        # DataFrame.append was removed in pandas 2.0; build the row and concat instead
        row = pd.DataFrame([{'epoch': epoch, 'train_loss': train_loss / train_len,
                             'test_loss': test_loss / test_len, 'test_accuracy': acc}])
        df = pd.concat([df, row], ignore_index=True)
if save_results:
    df.set_index(['epoch'], inplace=True)
    df.to_csv(args.o)
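# The metrics CSV can then be inspected offline with pandas, e.g.
# (path hypothetical):
#   pd.read_csv('results.csv', index_col='epoch')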