forked from michjk/Question_Classifier_Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path train_cnn.py
More file actions
62 lines (41 loc) · 2.2 KB
/
train_cnn.py
File metadata and controls
62 lines (41 loc) · 2.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import time
import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from model_module.cnn_classifier import CNNClassifier
from data_module.data_preprocessor import *
import os
import random
import datetime
from model_module.model_runner import ModelRunner
from utils import load_training_parameter_from_json, filter_dotdict_class_propoperty, FactoryClass
from data_module.data_writer import LearningWriter, PreprocessingPipelineWriter
import argparse
# Seed every random number generator this run may draw from -- Python's
# `random`, NumPy, and PyTorch (CPU and all CUDA devices) -- with one fixed
# value, so repeated runs start from the same RNG state. (Seeding alone does
# not make GPU kernels deterministic.)
_SEED = 1
random.seed(_SEED)
np.random.seed(_SEED)
torch.manual_seed(_SEED)
torch.cuda.manual_seed_all(_SEED)
# ---------------------------------------------------------------------------
# Command line: the single argument is the path of the JSON parameter file
# that drives the whole training run.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Train the CNN question classifier.")
# required=True: without it an omitted --path silently becomes None and the
# run only fails later inside load_training_parameter_from_json with an
# unclear error; argparse now prints the usage message and exits instead.
parser.add_argument("--path", required=True, help="path parameter json file")
param_json_path = parser.parse_args().path
param = load_training_parameter_from_json(param_json_path)

# Constructor parameters for CNNClassifier, derived from the loaded parameters
# (presumably filtered down to the names CNNClassifier accepts -- verify in utils).
cnn_parameter = filter_dotdict_class_propoperty(param, CNNClassifier)

# Writer configured with the result folder and the text/label pipeline file
# paths; it is passed to load_dataset below.
preprocessing_pipeline_writer = PreprocessingPipelineWriter(
    param.result_folder_path,
    param.saved_text_pipeline_file_path,
    param.saved_label_pipeline_file_path,
)

# Load train/dev data together with the vocabulary size, number of labels,
# label map and (optional) pretrained embedding weights.
train_data, dev_data, vocab_size, label_size, label_map, pretrained_embedding_weight = load_dataset(
    param.train_dataset_path,
    param.dev_dataset_path,
    param.max_text_length,
    preprocessing_pipeline_writer,
    pretrained_word_embedding_name=param.pretrained_word_embedding_name,
    pretrained_word_embedding_path=param.pretrained_word_embedding_path,
)

# These values are only known after the dataset has been loaded, so they are
# filled into the model parameters here rather than read from the JSON file.
cnn_parameter.vocab_size = vocab_size
cnn_parameter.label_size = label_size
cnn_parameter.pretrained_embedding_weight = pretrained_embedding_weight

# Factories for the model, loss and optimizer, handed to ModelRunner below.
# NOTE(review): nn.NLLLoss expects log-probabilities, so CNNClassifier
# presumably ends in log_softmax -- confirm in model_module/cnn_classifier.py.
model_factory = FactoryClass(CNNClassifier, cnn_parameter)
loss_factory = FactoryClass(nn.NLLLoss)
optimizer_param_dict = filter_dotdict_class_propoperty(param, optim.Adam)
optimizer_factory = FactoryClass(optim.Adam, optimizer_param_dict)

learning_logger = LearningWriter(
    label_map,
    param.result_folder_path,
    param.saved_model_file_path,
    param.train_log_folder_path,
    param.dev_log_folder_path,
    param.confusion_matrix_folder_path,
)
model_runner = ModelRunner(
    model_factory,
    loss_factory,
    optimizer_factory,
    param.epoch,
    param.batch_size,
    learning_logger,
    param.use_gpu,
)

# perf_counter is monotonic and high-resolution, so the reported duration is
# not skewed by wall-clock adjustments during a long run (time.time can be).
start_time = time.perf_counter()
model_runner.learn(train_data, dev_data)
print("Overall time elapsed {} sec".format(time.perf_counter() - start_time))