forked from sw32-seo/GTA
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path Generate_test_prediction.py
More file actions
28 lines (25 loc) · 957 Bytes
/
Generate_test_prediction.py
File metadata and controls
28 lines (25 loc) · 957 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import argparse
import os
import re
import subprocess
import sys
# Run translate.py on a test set using the newest checkpoint found under
# <model_path>/models, writing predictions to <model_path>/pred.


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``src``, ``model_path`` and ``gpu``.
    """
    parser = argparse.ArgumentParser(description="Get saved data/model path")
    parser.add_argument('--src', '-src', type=str,
                        default='data/USPTO-50k_no_rxn/src-test.txt',
                        help='source file to translate')
    # Required: without it the models directory cannot be located, and the
    # original fell through to os.path.join(None, ...) with a TypeError.
    parser.add_argument('--model_path', '-model_path', type=str, required=True,
                        help='run directory containing a "models" subdirectory')
    parser.add_argument('--gpu', '-gpu', type=str, default='0',
                        help='GPU id forwarded to translate.py (default: 0)')
    return parser.parse_args(argv)


def natural_sort_key(name):
    """Return a sort key that orders embedded integers numerically.

    ``re.split`` with a capturing group alternates non-digit / digit chunks,
    so odd positions are always digit runs; converting those to ``int`` makes
    'model_step_9000.pt' sort before 'model_step_10000.pt'. The raw name is
    appended as a tie-breaker (e.g. '01' vs '1') so ordering is deterministic.
    """
    parts = re.split(r'(\d+)', name)
    return [int(tok) if i % 2 else tok for i, tok in enumerate(parts)], name


def find_best_model(models_dir):
    """Return the path of the entry in ``models_dir`` that sorts last.

    Entries are ordered with :func:`natural_sort_key`, so the highest
    step/epoch number wins. NOTE(review): this assumes the last checkpoint
    in that order is the "best" one; confirm against how training saves
    checkpoints.

    Raises:
        FileNotFoundError: if ``models_dir`` is missing or empty.
    """
    entries = os.listdir(models_dir)
    if not entries:
        raise FileNotFoundError('no checkpoints found in %s' % models_dir)
    return os.path.join(models_dir, max(entries, key=natural_sort_key))


def build_translate_command(model_file, src, output_dir, gpu='0'):
    """Build the argv list used to invoke translate.py.

    ``sys.executable`` is used instead of a bare 'python' so the child runs
    under the same interpreter/environment as this script. 'translate.py' is
    resolved relative to the current working directory, as before.
    """
    return [
        sys.executable, 'translate.py',
        '-gpu', str(gpu),
        '-model', model_file,
        '-src', src,
        '-output_dir', output_dir,
        '-batch_size', '32',
        '-replace_unk',
        '-max_length', '256',
        # NOTE(review): n_best (50) exceeds beam_size (10); confirm that
        # translate.py supports returning more hypotheses than beams.
        '-n_best', '50',
        '-beam_size', '10',
        '-log_probs',
        '-n_translate_latent', '0',
    ]


def main(argv=None):
    """Select the newest checkpoint and run translate.py on the test set.

    Returns:
        The exit status of the translate.py child process.
    """
    args = parse_args(argv)
    models_path = os.path.join(args.model_path, 'models')
    best_model_path = find_best_model(models_path)
    cmd = build_translate_command(best_model_path, args.src,
                                  os.path.join(args.model_path, 'pred'),
                                  gpu=args.gpu)
    return subprocess.call(cmd)


if __name__ == '__main__':
    # Propagate translate.py's exit status so callers/shell pipelines see failures.
    sys.exit(main())