-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathinference.py
More file actions
68 lines (58 loc) · 2.49 KB
/
inference.py
File metadata and controls
68 lines (58 loc) · 2.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
import os
from datasets import get_dataset, DATASETS, get_num_classes
from architectures_unstructured import ARCHITECTURES, get_architecture
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import SGD, Optimizer, Adam
from torch.optim.lr_scheduler import StepLR
import datetime
import time
import numpy as np
import copy
import types
from math import ceil
from train_utils import AverageMeter, accuracy, accuracy_list, init_logfile, log
from utils import *
# Command-line interface of the script. Parsing happens at import time and the
# result is stored in the module-level ``args`` that ``main`` reads.
# NOTE(review): ``relu_counting`` reads ``args.threshold``, but no ``--threshold``
# option is declared here -- confirm where that attribute is expected to come from.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('dataset', type=str, choices=DATASETS)
parser.add_argument('arch', type=str, choices=ARCHITECTURES)
# The value is handed directly to ``torch.load`` in ``main``, so it has to be the
# checkpoint file itself rather than a directory.
parser.add_argument('savedir', type=str,
                    help='path of the checkpoint file to load the model from')
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--batch', default=256, type=int, metavar='N',
                    help='batchsize (default: 256)')
# Used in ``main`` as the index passed to ``torch.cuda.set_device``.
parser.add_argument('--gpu', default=0, type=int,
                    help='id(s) for CUDA_VISIBLE_DEVICES')
parser.add_argument('--print-freq', default=100, type=int,
                    metavar='N', help='print frequency (default: 100)')
parser.add_argument('--stride', type=int, default=1, help='conv1 stride')
args = parser.parse_args()
def relu_counting(net, args):
    """Count the entries of ``alpha`` parameters that exceed ``args.threshold``.

    Every parameter whose name contains the substring ``'alpha'`` is compared
    element-wise against the threshold, and the number of entries strictly
    greater than it is accumulated over all such parameters.

    Args:
        net: Object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs.
        args: Object with a ``threshold`` attribute.
            NOTE(review): the argument parser visible in this file declares no
            ``--threshold`` option -- confirm the caller supplies it.

    Returns:
        The plain integer ``0`` when no parameter name contains ``'alpha'``;
        otherwise a 0-dim integer tensor holding the total count.
    """
    relu_count = 0
    for name, param in net.named_parameters():
        if 'alpha' not in name:
            continue
        # The comparison already produces a boolean mask, so summing it counts
        # the entries above the threshold; ``detach`` keeps autograd out of it.
        relu_count += (param.detach() > args.threshold).sum()
    return relu_count
def main():
    """Load a saved classifier and run it over the test split of the dataset.

    Reads the module-level ``args``: selects the CUDA device given by
    ``args.gpu``, builds a non-shuffled test ``DataLoader``, restores the
    ``'state_dict'`` entry of the checkpoint at ``args.savedir`` into the
    architecture named by ``args.arch``, switches the model to eval mode and
    passes it to ``model_inference``.
    """
    torch.cuda.set_device(args.gpu)
    device = torch.device("cuda")

    # Test data; memory pinning is enabled only for the "imagenet" dataset.
    eval_set = get_dataset(args.dataset, 'test')
    use_pinned = args.dataset == "imagenet"
    eval_loader = DataLoader(
        eval_set,
        shuffle=False,
        batch_size=args.batch,
        num_workers=args.workers,
        pin_memory=use_pinned,
    )

    # Build the network and restore its weights from the checkpoint.
    model = get_architecture(args.arch, args.dataset, device, args)
    ckpt = torch.load(args.savedir, map_location=device)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    print("Loaded the base_classifier")

    # Evaluate the restored model on the test loader.
    # NOTE(review): the returned value is not used afterwards; presumably
    # ``display=True`` makes ``model_inference`` report it -- confirm.
    test_accuracy = model_inference(model, eval_loader, device, display=True)
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()