forked from branislav1991/PyTorchProjectFramework
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvalidate.py
More file actions
42 lines (33 loc) · 1.39 KB
/
validate.py
File metadata and controls
42 lines (33 loc) · 1.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import argparse
from datasets import create_dataset
from utils import parse_configuration
from models import create_model
import os
def validate(config_file):
    """Performs validation of a specified model.

    Input params:
        config_file: Either a string with the path to the JSON
            system-specific config file or a dictionary containing
            the system-specific, dataset-specific and
            model-specific settings.

    The parsed configuration must provide the keys 'val_dataset_params'
    and 'model_params'; 'model_params' must in turn contain
    'load_checkpoint', whose value is passed to the model's pre- and
    post-epoch callbacks. Progress messages are printed to stdout and
    nothing is returned.
    """
    print('Reading config file...')
    configuration = parse_configuration(config_file)

    print('Initializing dataset...')
    val_dataset = create_dataset(configuration['val_dataset_params'])
    val_dataset_size = len(val_dataset)
    print('The number of validation samples = {0}'.format(val_dataset_size))

    print('Initializing model...')
    model_params = configuration['model_params']
    model = create_model(model_params)
    model.setup()
    model.eval()

    model.pre_epoch_callback(model_params['load_checkpoint'])

    # The batch index is not needed, so iterate the dataset directly
    # instead of wrapping it in enumerate().
    for data in val_dataset:
        model.set_input(data)  # unpack data from data loader
        model.test()  # run inference

    model.post_epoch_callback(model_params['load_checkpoint'])
if __name__ == '__main__':
    # Script entry point: takes one positional argument, the config file
    # path, and forwards it to validate().
    cli = argparse.ArgumentParser(description='Perform model validation.')
    cli.add_argument('configfile', help='path to the configfile')
    validate(cli.parse_args().configfile)