-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy path: main_test.py
More file actions
98 lines (87 loc) · 3.59 KB
/
main_test.py
File metadata and controls
98 lines (87 loc) · 3.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
import shutil
import os
import time
import torch
import warnings
import torch.nn as nn
import torch.nn.parallel
import torch.optim
from models.VGG_models import *
import data_loaders
from functions import TET_loss, seed_all
from main_training_parallel import train, test
# Make all eight GPUs visible to CUDA before any device is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
# Fall back to CPU when no CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _str2bool(value):
    """Parse a command-line string into a real boolean.

    argparse's ``type=bool`` is broken for flags: any non-empty string is
    truthy, so ``--TET False`` would still yield ``True``.  This helper
    accepts the usual spellings of true/false and raises a clear error for
    anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))


parser = argparse.ArgumentParser(description='PyTorch Temporal Efficient Training')
parser.add_argument('-j',
                    '--workers',
                    default=16,
                    type=int,
                    metavar='N',
                    help='number of data loading workers (default: 16)')
parser.add_argument('--epochs',
                    default=150,
                    type=int,
                    metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start_epoch',
                    default=0,
                    type=int,
                    metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b',
                    '--batch_size',
                    default=128,
                    type=int,
                    metavar='N',
                    help='mini-batch size (default: 128), this is the total '
                    'batch size of all GPUs on the current node when '
                    'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr',
                    '--learning_rate',
                    default=0.001,
                    type=float,
                    metavar='LR',
                    help='initial learning rate',
                    dest='lr')
parser.add_argument('--seed',
                    default=1000,
                    type=int,
                    help='seed for initializing training. ')
parser.add_argument('-T',
                    '--time',
                    default=2,
                    type=int,
                    metavar='N',
                    help='snn simulation time (default: 2)')
parser.add_argument('--means',
                    default=1.0,
                    type=float,
                    metavar='N',
                    help='make all the potential increment around the means (default: 1.0)')
parser.add_argument('--TET',
                    default=True,
                    type=_str2bool,
                    metavar='N',
                    help='if use Temporal Efficient Training (default: True)')
parser.add_argument('--lamb',
                    default=1e-3,
                    type=float,
                    metavar='N',
                    help='adjust the norm factor to avoid outlier (default: 1e-3)')
# Parse CLI arguments at import time so `args` is available module-wide
# (the __main__ block below reads it directly).
args = parser.parse_args()
if __name__ == '__main__':
    # Seed all RNG sources for a reproducible evaluation run.
    seed_all(args.seed)

    # Only the validation split is evaluated here; the training split is
    # returned by the builder but not needed for inference, so no
    # train DataLoader is constructed.
    _, val_dataset = data_loaders.build_dvscifar('cifar-dvs')  # change to your path
    test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.workers,
                                              pin_memory=True)

    model = VGGSNNwoAP()
    # Load on CPU first so opening the checkpoint never requires a GPU.
    state_dict = torch.load('VGGSNN_woAP.pth', map_location=torch.device('cpu'))
    # strict=False tolerates renamed/extra keys, but surface any mismatch so a
    # silently half-loaded model cannot masquerade as a valid evaluation.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn('checkpoint key mismatch: missing=%s unexpected=%s'
                      % (load_result.missing_keys, load_result.unexpected_keys))

    # Replicate the model across all visible GPUs (harmless no-op on CPU).
    parallel_model = torch.nn.DataParallel(model)
    parallel_model.to(device)

    facc = test(parallel_model, test_loader, device)
    print('Test Accuracy of the model: %.3f' % facc)