-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
129 lines (105 loc) · 3.79 KB
/
inference.py
File metadata and controls
129 lines (105 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import sys
import torch
import hydra
from omegaconf import DictConfig, OmegaConf
from models.utils.reconstruct import (
extract_slices,
reconstruct_volume_nn,
save_volume,
reconstruct_volume_sr_fast,
reconstruct_volume_linear_fast,
)
from models.utils.utils import strip_orig_mod
def run_inference(model, cfg: DictConfig, device: torch.device):
    """Reconstruct and save volumes for a single patient scan.

    Depending on ``cfg.inference`` flags, saves up to three volumes under
    ``cfg.inference.save_path``: a nearest-neighbour baseline (``nn``), a fast
    linear-interpolation baseline (``linear``), and always the model's
    super-resolved output.

    Args:
        model: Trained network; switched to ``eval()`` before inference.
        cfg: Hydra config. Reads ``cfg.inference.*`` (paths, flags, stride,
            optional ``volume_depth``/``ap_voxel_count``/``probe_radius``)
            and ``cfg.data.patch_size``.
        device: Torch device passed to the linear reconstruction helper.
    """
    image_path = cfg.inference.image_path
    save_path = cfg.inference.save_path
    model_path = cfg.inference.model_path

    # Label outputs by the checkpoint's parent directory (typically the run
    # name); fall back to the checkpoint filename when there is none.
    model_dir_name = os.path.basename(os.path.dirname(model_path))
    if model_dir_name:
        model_label = model_dir_name
    else:
        model_label = os.path.splitext(os.path.basename(model_path))[0]
    model_label = model_label.replace(" ", "_")

    patient_name = os.path.basename(image_path.rstrip("/"))

    # Ensure the output directory exists; save_volume targets live under it.
    os.makedirs(save_path, exist_ok=True)

    print(f"Loading data from {image_path}...")
    extracted_slices = extract_slices(image_path)
    print("Data loaded successfully")
    print("--------------------------------")

    # These may be None when unset; presumably the reconstruction helpers
    # choose sensible defaults in that case — TODO confirm.
    depth = cfg.inference.get("volume_depth")
    ap_voxel_count = cfg.inference.get("ap_voxel_count")
    probe_radius = cfg.inference.get("probe_radius", 12.5)
    patch_size = cfg.data.patch_size
    stride = cfg.inference.stride

    if cfg.inference.get("nn", False):
        print("Saving original volume (NN)...")
        save_volume(
            extracted_slices,
            reconstruct_volume_nn(
                extracted_slices,
                depth=depth,
                ap_voxel_count=ap_voxel_count,
                probe_radius=probe_radius,
            ),
            os.path.join(save_path, f"{patient_name}_nn"),
        )
        print("Original volume saved successfully")
        print("--------------------------------")

    if cfg.inference.get("linear", False):
        print("Saving linear interpolation volume...")
        save_volume(
            extracted_slices,
            reconstruct_volume_linear_fast(
                extracted_slices,
                depth=depth,
                ap_voxel_count=ap_voxel_count,
                patch_size=patch_size,
                stride=stride,
                probe_radius=probe_radius,
                device=device,
            ),
            os.path.join(save_path, f"{patient_name}_sr_linear"),
        )
        print("Linear interpolation volume saved successfully")
        print("--------------------------------")

    print("Beginning model inference...")
    model.eval()
    # eval() alone does not disable autograd; no_grad() avoids building the
    # computation graph (and its memory cost) during inference.
    with torch.no_grad():
        sr_volume = reconstruct_volume_sr_fast(
            model,
            extracted_slices,
            depth=depth,
            ap_voxel_count=ap_voxel_count,
            patch_size=patch_size,
            stride=stride,
            probe_radius=probe_radius,
        )
    save_volume(
        extracted_slices,
        sr_volume,
        os.path.join(save_path, f"{patient_name}_{model_label}_sr"),
    )
    print("Done!")
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig):
    """Entry point: build the model, restore its weights, run inference."""
    # Prefer the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print("--------------------------------")

    print("Building model...")
    model = hydra.utils.instantiate(cfg.model).to(device)

    # Restore the checkpoint. The state dict may be nested under "model_state"
    # or "model", or the file may be a bare state dict; try keys in order.
    model_path = cfg.inference.model_path
    print(f"Loading weights from: {model_path}")
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint
    for nesting_key in ("model_state", "model"):
        if nesting_key in checkpoint:
            state_dict = checkpoint[nesting_key]
            break
    model.load_state_dict(strip_orig_mod(state_dict))
    model = torch.compile(model)
    print("Model built and loaded successfully")
    print("--------------------------------")

    run_inference(model, cfg, device)


if __name__ == "__main__":
    main()