-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathstep3b_predict_heatmaps.py
More file actions
24 lines (19 loc) · 1002 Bytes
/
step3b_predict_heatmaps.py
File metadata and controls
24 lines (19 loc) · 1002 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import os
import argparse
from scripts.data_loading import MemBrain_datamodule
from scripts.trainer import MemBrainer
from config import *
def _parse_args(argv=None):
    """Parse the command line for the heatmap-prediction step.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            makes argparse fall back to ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with a single ``ckpt`` attribute.

    Raises:
        SystemExit: Via ``parser.error`` (exit status 2) when the arguments
            are malformed or the checkpoint path is not an existing file.
    """
    parser = argparse.ArgumentParser(
        description="Predict heatmaps using a trained model checkpoint.")
    parser.add_argument("ckpt", type=str, help="Path to trained model checkpoint.")
    args = parser.parse_args(argv)
    # Fail fast with a usage error instead of only discovering a bad path
    # after the data module and trainer have already been constructed.
    if not os.path.isfile(args.ckpt):
        parser.error(f"checkpoint file not found: {args.ckpt}")
    return args


def main(argv=None):
    """Run heatmap prediction for the project configured in ``config``.

    Uses PROJECT_DIRECTORY, PROJECT_NAME, BATCH_SIZE, TRAINING_PARTICLE_DISTS,
    MAX_PARTICLE_DISTANCE and BOX_RANGE (star-imported from ``config``).
    Builds the data module and trainer from the checkpoint given on the
    command line, then calls ``trainer.predict`` with ``<project>/heatmaps``
    as the output directory.

    Args:
        argv: Optional argument list for programmatic use. ``None`` (the
            default) parses ``sys.argv``, as when run as a script.
    """
    # Parse here rather than at import time, so importing this module has
    # no dependency on the current process's command line.
    args = _parse_args(argv)

    project_directory = os.path.join(PROJECT_DIRECTORY, PROJECT_NAME)
    # This script only reads the STAR file; it does not create it.
    # NOTE(review): presumably written by an earlier pipeline step — confirm.
    out_star_name = os.path.join(project_directory, 'rotated_volumes',
                                 PROJECT_NAME + '_with_inner_outer.star')
    heatmap_out_dir = os.path.join(project_directory, 'heatmaps')

    dm = MemBrain_datamodule(out_star_name, BATCH_SIZE,
                             part_dists=TRAINING_PARTICLE_DISTS,
                             max_dist=MAX_PARTICLE_DISTANCE)
    trainer = MemBrainer(box_range=BOX_RANGE, dm=dm, project_dir=project_directory,
                         star_file=out_star_name, ckpt=args.ckpt,
                         part_dists=TRAINING_PARTICLE_DISTS)
    trainer.predict(heatmap_out_dir, star_file=out_star_name)


if __name__ == '__main__':
    main()