-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
80 lines (66 loc) · 3.27 KB
/
train.py
File metadata and controls
80 lines (66 loc) · 3.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
import os
import subprocess
import sys
# Backbone networks that have a matching ``train_<name>`` module to launch.
base_models = [
    "Geoformer",
    "PaiNN",
    "Equiformer",
]

# Accepted values for ``--spectrum-type``.
spectrum_types = [
    "Naive",
    "GMM",
    "FC",
]

# Accepted values for ``--data-path``; each one is used as a dataset root
# directory relative to the working directory.
data_path_list = [
    "IrDB",
    "IrDB_uff",
    "IrDB_murcko",
    "PtDB",
]
def get_args():
    """Parse the launcher's command-line options.

    Registers ``--spectrum-type`` (default ``'FC'``), ``--base-model``
    (default ``None``), ``--batch-size`` (int, default ``16``) and
    ``--data-path`` (default ``'IrDB'``). The three string options are
    restricted to the module-level choice lists.

    Returns:
        argparse.Namespace: The parsed options, read from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Spectrum prediction based on physics-informed neural network")

    # (flag, add_argument keyword options), registered in this order so the
    # generated --help output keeps the same layout.
    option_table = (
        (
            "--spectrum-type",
            dict(
                type=str,
                default='FC',
                choices=spectrum_types,
                help="FC: Spectrum prediction based on Franck-Condon progression. GMM: Spectrum prediction based on Gaussian Mixture-Model Naive: Not consider Spectrum Loss.",
            ),
        ),
        ("--base-model", dict(type=str, default=None, choices=base_models)),
        ("--batch-size", dict(type=int, default=16)),
        ("--data-path", dict(type=str, default='IrDB', choices=data_path_list)),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
def main(i_seed=0, i_fold=0):
    """Build and run the training command for the selected base model.

    Options are read through ``get_args()``. When no base model was given,
    a usage message is printed and nothing is launched. Otherwise the
    cross-validation split file is located, and ``train_<base_model>`` is
    run as a module in a child process using the current interpreter.

    Args:
        i_seed: Seed index. Used in the log directory, in the split file
            name, and passed to the training module as ``--seed``.
        i_fold: Fold index. Used in the log directory and in the split
            file name.

    Returns:
        int: ``0`` when only the usage message was printed, otherwise the
        return code of the training process.

    Raises:
        ValueError: If the base model, spectrum type or data path is not
            one of the known values.
        FileNotFoundError: If the split ``.npz`` file for the requested
            seed/fold does not exist.
    """
    args = get_args()

    if args.base_model is None or args.spectrum_type is None:
        print("python train.py --base-model {Geoformer|PaiNN|Equiformer} [--spectrum-type {Naive|GMM|FC} (default:FC)] [--batch-size <int> (default:16)] [--data-path {IrDB|IrDB_uff|IrDB_murcko|PtDB} (default:IrDB)]")
        print("Please select base-model from Geoformer, PaiNN, or Equiformer")
        print("Please select spectrum-type from Naive, GMM, or FC")
        # This path has always ended the script without an error status.
        return 0

    # Defensive re-validation: argparse ``choices`` normally rejects these
    # earlier, but this keeps main() safe if the parser is relaxed.
    if args.base_model not in base_models:
        raise ValueError("Undefined model. Please select one from Geoformer, PaiNN, or Equiformer.")
    if args.spectrum_type not in spectrum_types:
        raise ValueError("Undefined spectrum_type. Please select one from Naive, GMM, or FC.")

    # Sub-directory of ``<data_path>/raw`` that holds the split files.
    split_dirs = {
        'IrDB': 'CV811',
        'IrDB_uff': 'CV811',
        'IrDB_murcko': 'CV_murcko',
        'PtDB': 'CV_10fold',
    }
    if args.data_path not in split_dirs:
        # Previously an unknown data path fell through the if/elif chain and
        # crashed later with an unbound ``split_npz``.
        raise ValueError(f"Undefined data_path. Please select one from {', '.join(split_dirs)}.")

    train_module = f'train_{args.base_model}'
    log_path = f'results_{args.base_model}/{i_seed}/{i_fold}'
    split_npz = f'{args.data_path}/raw/{split_dirs[args.data_path]}/splits.{i_seed}.{i_fold}.npz'
    if not os.path.exists(split_npz):
        raise FileNotFoundError(f"{split_npz} file is not exist. Please check the i_seed and i_fold")

    if args.base_model == 'Geoformer':
        # Geoformer takes a YAML config and uses its own option names.
        model_args = [
            '--conf', f'geoformer/examples/{args.spectrum_type}.yml',
            '--log-dir', log_path,
            '--seed', str(i_seed),
            '--splits', split_npz,
            '--batch-size', str(args.batch_size),
            '--dataset-root', args.data_path,
        ]
    else:
        # PaiNN and Equiformer are launched with the same option set.
        model_args = [
            '--spectrum-type', args.spectrum_type,
            '--output-dir', log_path,
            '--split-index-npz', split_npz,
            '--seed', str(i_seed),
            '--batch-size', str(args.batch_size),
            '--data-path', args.data_path,
        ]

    # Argument list + no shell: nothing is re-parsed by a shell, and
    # sys.executable guarantees the child uses this same interpreter.
    # NOTE(review): an earlier, commented-out variant wrapped this command as
    #   phd run -p mai_small_gpu -ng 1 -GR "name==H100" -- <command>
    # (presumably a cluster job launcher); prepend that to ``command`` if
    # scheduler submission is needed again.
    command = [sys.executable, '-m', train_module, *model_args]
    completed = subprocess.run(command, check=False)
    return completed.returncode


if __name__ == "__main__":
    # Propagate the training process's status so a failed run is visible to
    # the calling shell or scheduler.
    sys.exit(main())