-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patheval.py
More file actions
382 lines (340 loc) · 22.2 KB
/
eval.py
File metadata and controls
382 lines (340 loc) · 22.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
import json
import os
import random
import shutil
import time
from argparse import ArgumentParser

import numpy as np
import torch
from easydict import EasyDict

import cardio_volume_skewer
import dataset_obj_adding as doa
import flow_n_corr_utils
import four_d_ct_cost_unrolling
import p2p_correspondence
import three_d_data_manager as dt_mng
from utils import remove_non_floats_from_dict, json2dict, handle_pred_flow
# --- Command-line arguments and run configuration ---
parser = ArgumentParser()
parser.add_argument("--config_path", type=str, default="configs/config_templates/general_config_template.json",
                    help="Path to the JSON run configuration.")
# Default is "0" (string, matching type=str); previously the default was the
# int 0, inconsistent with the declared type. Pass no value only by editing
# this default to None, in which case the JSON config's cuda device is kept.
parser.add_argument("--cuda_device", type=str, default="0",
                    help="CUDA device index; overrides the JSON config value when not None.")
parser.add_argument("--MODE", type=str, default="TRAIN_AND_EVAL",
                    help='One of "TRAIN", "EVAL", "TRAIN_AND_EVAL", "INFER" (substring match below).')
args = parser.parse_args()

config = json2dict(args.config_path)
if args.cuda_device is not None:  # override json cuda device with cmd-line cuda device
    config.general.cuda_device = int(args.cuda_device)
MODE = args.MODE

# Stagger startup by up to 3 s so multiple runs launched simultaneously do not
# collide on the timestamp-named output directory created below.
time.sleep(random.random() * 3)
# --- Output locations and global runtime setup ---
top_output_dir = config.general.top_output_dir
# Timestamped per-run output directory; exist_ok guards the unlikely case of
# two runs landing on the same second.
outputs_path = os.path.join(top_output_dir, f"outputs_{time.strftime('%Y%m%d_%H%M%S')}")
os.makedirs(outputs_path, exist_ok=True)
print(f"Output directory: {outputs_path}")
# Root directory of the (possibly pre-existing) processed dataset on disk.
dataset_target_path = os.path.join(top_output_dir, config.dataset.dataset_target_folder_name)
template_timestep_name = config.dataset.template_timestep_name
# Hard-coded timestep name under which the synthetic "unlabeled" frame is registered.
unlabeled_timestep_name = "28"
# Snapshot the effective config next to the outputs for reproducibility.
dt_mng.write_config_file(outputs_path, "main_train", config)
torch.set_num_threads(config.general.torch_num_threads)
print("Init dataset")
# Open (or resume) the on-disk dataset; read_filepaths indexes artifacts left
# by previous runs so already-created files can be reused.
dataset = dt_mng.Dataset(target_root_dir=dataset_target_path, file_paths=dt_mng.read_filepaths(dataset_target_path))
# Two ingestion routes for the template frame:
#   - "magix" data: DICOM image + zxy-ordered voxel mask arrays
#   - mm-whs / 3d_slicer data: xyz-ordered numpy arrays
if config.dataset.template_dicom_path is not None: # magix
    abs_template_dicom_path = os.path.join(os.getcwd(), config.dataset.template_dicom_path)
    dataset = doa.add_template_image_from_dicom(dataset, template_timestep_name, abs_template_dicom_path)
else: # mm-whs, 3d_slicer
    dataset = doa.add_image_from_xyz_arr(dataset, template_timestep_name, config.dataset.template_xyz_arr_path) # note that dim_x==dim_y
if config.dataset.template_zxy_voxels_mask_arr_path is not None: # magix
    dataset = doa.add_mask_from_zxy_arr(dataset, template_timestep_name, config.dataset.template_zxy_voxels_mask_arr_path, config.dataset.voxels_mask_smoothing, mask_or_extra_mask="mask") # note that dim_x==dim_y
    dataset = doa.add_mask_from_zxy_arr(dataset, template_timestep_name, config.dataset.template_zxy_voxels_extra_mask_arr_path, config.dataset.voxels_mask_smoothing, mask_or_extra_mask="extra_mask") # note that dim_x==dim_y
else: # mm-whs, 3d_slicer
    dataset = doa.add_mask_from_xyz_arr(dataset, template_timestep_name, os.path.join(os.getcwd(),config.dataset.template_xyz_voxels_mask_arr_path), config.dataset.voxels_mask_smoothing, mask_or_extra_mask="mask")
    dataset = doa.add_mask_from_xyz_arr(dataset, template_timestep_name, os.path.join(os.getcwd(),config.dataset.template_xyz_voxels_extra_mask_arr_path), config.dataset.voxels_mask_smoothing, mask_or_extra_mask="extra_mask")
# Canonical template-frame paths as recorded by the dataset manager.
template_3dimg_path = dataset.file_paths.xyz_arr[template_timestep_name] # TODO: read from the filepaths backup so the add_template_image / add_mask calls above become unnecessary
template_mask_path = dataset.file_paths.xyz_voxels_mask_smooth[template_timestep_name]
template_extra_mask_path = dataset.file_paths.xyz_voxels_extra_mask_smooth[template_timestep_name]
print("Generating synthetic dataset")
# Skew the template volume into a synthetic cardiac sequence with a known
# ground-truth deformation field (endpoints of the skew controlled by the
# r/theta/h config values). Returned paths: template & last-frame images,
# LV masks, extra ("shell") masks, the GT flow, and the
# radial/circumferential/longitudinal coordinate arrays used later for
# direction-resolved error metrics.
template_synthetic_img_path, unlabeled_synthetic_img_path, \
template_synthetic_mask_path, unlabeled_synthetic_mask_path, \
template_synthetic_extra_mask_path, unlabeled_synthetic_extra_mask_path, synthetic_flow_path, \
error_radial_coordinates_path, error_circumferential_coordinates_path, error_longitudinal_coordinates_path = \
    cardio_volume_skewer.create_skewed_sequences(
        r1s_end=config.synthetic_data.r1,
        r2s_end=config.synthetic_data.r2,
        theta1s_end=config.synthetic_data.theta1,
        theta2s_end=config.synthetic_data.theta2,
        hs_end=config.synthetic_data.h,
        output_dir=dataset_target_path,
        template_3dimage_path=template_3dimg_path,
        template_mask_path=template_mask_path,
        template_extra_mask_path=template_extra_mask_path,
        num_frames=6,  # length of the generated synthetic sequence
        zero_outside_mask=config.synthetic_data.zero_outside_mask,
        blur_around_mask_radious=config.synthetic_data.blur_around_mask_radious,
        theta_distribution_method=config.synthetic_data.theta_distribution_method,
        scale_down_by=1  # generate at full resolution; the models apply their own scale_down_by later
    )
print("Processing synthetic dataset")
# Register the last synthetic frame under the "unlabeled" timestep so it goes
# through the same smoothing / bookkeeping pipeline as the template frame.
dataset = doa.add_image_from_xyz_arr(dataset, unlabeled_timestep_name, unlabeled_synthetic_img_path)
dataset = doa.add_mask_from_xyz_arr( dataset, unlabeled_timestep_name, unlabeled_synthetic_mask_path, config.dataset.voxels_mask_smoothing, mask_or_extra_mask="mask")
dataset = doa.add_mask_from_xyz_arr( dataset, unlabeled_timestep_name, unlabeled_synthetic_extra_mask_path, config.dataset.voxels_mask_smoothing, mask_or_extra_mask="extra_mask")
# Canonical unlabeled-frame paths as recorded by the dataset manager.
unlabeled_3dimg_path = dataset.file_paths.xyz_arr[unlabeled_timestep_name]
unlabeled_mask_path = dataset.file_paths.xyz_voxels_mask_smooth[unlabeled_timestep_name]
unlabeled_extra_mask_path = dataset.file_paths.xyz_voxels_extra_mask_smooth[unlabeled_timestep_name]
print("Creating meshes")
# For both the synthetic "unlabeled" frame and the template frame, build the
# geometry artifacts needed by the correspondence step:
#   mesh -> LBOs -> (optional) LBO-smoothed mesh -> vertex normals.
for timestep_name in unlabeled_timestep_name, template_timestep_name:
    print("Creating mesh")
    mesh_creation_args = dt_mng.MeshSmoothingCreationArgs(marching_cubes_step_size=2)
    mesh_data_creator = dt_mng.MeshDataCreator(source_path=None, sample_name=timestep_name, hirarchy_levels=2, creation_args=mesh_creation_args)
    dataset.add_sample(mesh_data_creator)
    print("Creating LBOs from mesh")
    lbo_creation_args = dt_mng.LBOCreationArgs(num_LBOs=config.dataset.num_lbos, is_point_cloud=False, geometry_path=dataset.file_paths.mesh[timestep_name], orig_geometry_name="mesh", use_torch=True)
    lbos_data_creator = dt_mng.LBOsDataCreator(source_path=None, sample_name=timestep_name, hirarchy_levels=2, creation_args=lbo_creation_args)
    dataset.add_sample(lbos_data_creator)
    if config.dataset.mesh_smoothing:
        print("Smoothing mesh with lbos")
        smooth_mesh_creation_args = dt_mng.SmoothMeshCreationArgs(lbos_path=dataset.file_paths.mesh_lbo_data[timestep_name])
        smooth_lbo_mesh_data_creator = dt_mng.SmoothLBOMeshDataCreator(source_path=None, sample_name=timestep_name, hirarchy_levels=2, creation_args=smooth_mesh_creation_args)
        dataset.add_sample(smooth_lbo_mesh_data_creator)
    print("Computing vertex normals")
    # BUGFIX: geometry_path previously always pointed at the smoothed mesh even
    # when mesh_smoothing was disabled (while orig_geometry_name was already
    # conditional); use the geometry that matches orig_geometry_name.
    normals_geometry_path = dataset.file_paths.mesh_smooth[timestep_name] if config.dataset.mesh_smoothing else dataset.file_paths.mesh[timestep_name]
    vertex_normals_creation_args = dt_mng.VertexNormalsCreationArgs(geometry_path=normals_geometry_path, orig_geometry_name="mesh_smooth" if config.dataset.mesh_smoothing else "mesh")
    vertex_normals_data_creator = dt_mng.VertexNormalsDataCreator(source_path=None, sample_name=timestep_name, hirarchy_levels=2, creation_args=vertex_normals_creation_args)
    dataset.add_sample(vertex_normals_data_creator)
# Paths to the geometry artifacts created above. Uses attribute-style config
# access (config.dataset...) for consistency with the rest of this script
# (EasyDict supports both styles; behavior is unchanged).
# NOTE(review): the mesh filename is "smooth_mesh" while the normals filename
# embeds "mesh_smooth" — both strings are kept exactly as the data manager
# writes them; confirm against three_d_data_manager's output naming.
mesh_filename = "smooth_mesh" if config.dataset.mesh_smoothing else "mesh"
template_mesh_path = os.path.join(dataset_target_path, template_timestep_name, f"orig/meshes/{mesh_filename}.off")
unlabeled_mesh_path = os.path.join(dataset_target_path, unlabeled_timestep_name, f"orig/meshes/{mesh_filename}.off")
template_normals_path = os.path.join(dataset_target_path, template_timestep_name, f"orig/vertices_normals/vertices_normals_from_{'mesh_smooth' if config.dataset.mesh_smoothing else 'mesh'}.npy")
print("Using ZoomOut")
# Functional-maps + ZoomOut point-to-point correspondence between the template
# mesh and the synthetic frame's mesh; library defaults overridden from config.
config_zoomout = p2p_correspondence.get_default_config()
config_zoomout["plots"] = True
config_zoomout["main_output_dir"] = os.path.join(outputs_path, "zoomout_output_dir")
config_zoomout["default_output_subdir"] = config.dataset.dataset_target_folder_name
config_zoomout["process_params"]["descr_type"] = config.zoomout.descriptor_type
config_zoomout["process_params"]["n_ev"] = config.zoomout.num_eigenvectors
config_zoomout["process_params"]["n_descr"] = config.zoomout.num_preprocess_descriptors
config_zoomout["fm_fit_params"]["optinit"] = config.zoomout.optinit
config_zoomout["fm_fit_params"]["w_descr"] = config.zoomout.w_descr
config_zoomout["fm_fit_params"]["w_lap"] = config.zoomout.w_lap
config_zoomout["zoomout_refine_params"]["nit"] = config.zoomout.num_zoomout_iters
config_zoomout["preprocess"]["normalize_meshes_area"] = config.zoomout.normalize_meshes_area
config_zoomout["validation"]["mean_l1_flow_th"] = config.zoomout.mean_l1_flow_th
# valid_flow reports whether the correspondence passed the mean-L1 flow
# threshold configured above.
corr_infer_output_path, valid_flow = p2p_correspondence.get_correspondence(
    mesh1_path=template_mesh_path,
    mesh2_path=unlabeled_mesh_path,
    config=EasyDict(config_zoomout)
)
if not(valid_flow):
    # Single retry with a larger spectral basis; note the retry's result is
    # used below even if it is still flagged invalid.
    print("Retrying ZoomOut with more eigenvectors")
    config_zoomout["process_params"]["n_ev"] = config.zoomout.num_eigenvectors_for_2nd_try
    corr_infer_output_path, valid_flow = p2p_correspondence.get_correspondence(
        mesh1_path=template_mesh_path,
        mesh2_path=unlabeled_mesh_path,
        config=EasyDict(config_zoomout)
    )
print("Converting correspondence to constraints")
# Turn the sparse mesh-to-mesh correspondence into dense voxel-space flow
# constraints matching the template volume's shape (plus a flow channel dim).
sample_shape = dataset.get_xyz_arr(template_timestep_name).shape
config.constraints_creation.confidence_matrix_manipulations_config["plot_folder"] = outputs_path
two_d_constraints_path = flow_n_corr_utils.convert_corr_to_constraints(
    correspondence_h5_path=os.path.join(corr_infer_output_path, "model_inference.hdf5"),
    k_nn=config.constraints_creation.k_smooth_constraints_nn,
    output_folder_path=outputs_path,
    output_constraints_shape=(*sample_shape, 3),
    k_interpolate_sparse_constraints_nn=config.constraints_creation.k_interpolate_sparse_constraints_nn,
    confidence_matrix_manipulations_config=config.constraints_creation.confidence_matrix_manipulations_config
)
# Rasterize the template's per-vertex normals onto the voxel grid (consumed by
# the training/validation calls below) and save a visualization.
voxelized_normals_path = flow_n_corr_utils.voxelize_and_visualize_3d_vecs(
    vectors_cloud=np.load(template_normals_path),
    point_cloud=dt_mng.read_off(template_mesh_path)[0],
    output_shape=(*sample_shape, 3),
    text_vis="normals",
    output_arr_filename="normals",
    output_folder=outputs_path
)
print("Training without constraints")
# Stage 1: plain cost-unrolling backbone (no anatomical / constraint losses).
config_backbone = four_d_ct_cost_unrolling.get_default_backbone_config()
config_backbone["save_iter"] = 2
config_backbone["inference_args"]["inference_flow_median_filter_size"] = False  # disable median filtering of the inferred flow
config_backbone["epochs"] = config.fourD_ct_cost_unrolling.backbone.early_stopping.epochs
config_backbone["valid_type"] = "synthetic+basic"
config_backbone["w_sm_scales"] = config.fourD_ct_cost_unrolling.backbone["w_sm_scales"]
# Nest the library's default output subdir under this run's output directory.
config_backbone["output_root"] = os.path.join(outputs_path, config_backbone["output_root"])
config_backbone["visualization_arrow_scale_factor"] = 1
config_backbone["cuda_device"] = config.general.cuda_device
config_backbone["scale_down_by"] = config.fourD_ct_cost_unrolling.backbone.scale_down_by
config_backbone["metric_for_early_stopping"] = config.fourD_ct_cost_unrolling.backbone.early_stopping.metric_for_early_stopping
config_backbone["max_metric_not_dropping_patience"] = config.fourD_ct_cost_unrolling.backbone.early_stopping.max_metric_not_dropping_patience
if "TRAIN" in MODE:
    # Overfit the backbone on the single synthetic pair.
    # NOTE(review): the template_*/unlabeled_* keyword arguments are
    # cross-wired (template_image_path receives the unlabeled synthetic image
    # and vice versa) — presumably registering in the unlabeled->template
    # direction; confirm against overfit_backbone's argument convention.
    backbone_model_output_path = four_d_ct_cost_unrolling.overfit_backbone(
        template_image_path=unlabeled_synthetic_img_path,
        unlabeled_image_path=template_synthetic_img_path,
        template_LV_seg_path=unlabeled_synthetic_mask_path,
        unlabeled_LV_seg_path=template_synthetic_mask_path,
        template_shell_seg_path=unlabeled_synthetic_extra_mask_path,
        unlabeled_shell_seg_path=template_synthetic_extra_mask_path,
        flows_gt_path=synthetic_flow_path,
        error_radial_coordinates_path=error_radial_coordinates_path,
        error_circumferential_coordinates_path=error_circumferential_coordinates_path,
        error_longitudinal_coordinates_path=error_longitudinal_coordinates_path,
        voxelized_normals_path=voxelized_normals_path,
        args=EasyDict(config_backbone)
    )
    print("Backbone model output path:", backbone_model_output_path)
if "EVAL" in MODE:
    # Validate against the synthetic GT flow only, loading the checkpoint
    # produced by the TRAIN pass above.
    # NOTE(review): backbone_model_output_path is only bound when "TRAIN" is in
    # MODE, so running MODE="EVAL" on its own raises NameError here.
    config_backbone["valid_type"] = "synthetic"
    config_backbone["load"] = four_d_ct_cost_unrolling.get_checkpoints_path(backbone_model_output_path)
    print(f"Evaluating backbone with checkpoints: {config_backbone['load']}")
    errors_backbone = four_d_ct_cost_unrolling.validate_backbone(
        template_image_path=unlabeled_synthetic_img_path,
        unlabeled_image_path=template_synthetic_img_path,
        template_LV_seg_path=unlabeled_synthetic_mask_path,
        unlabeled_LV_seg_path=template_synthetic_mask_path,
        template_shell_seg_path=unlabeled_synthetic_extra_mask_path,
        unlabeled_shell_seg_path=template_synthetic_extra_mask_path,
        flows_gt_path=synthetic_flow_path,
        error_radial_coordinates_path=error_radial_coordinates_path,
        error_circumferential_coordinates_path=error_circumferential_coordinates_path,
        error_longitudinal_coordinates_path=error_longitudinal_coordinates_path,
        voxelized_normals_path=voxelized_normals_path,
        args=EasyDict(config_backbone)
    )
if "INFER" in MODE:
    # Run backbone inference with the freshly trained checkpoint.
    # NOTE(review): like EVAL, this relies on backbone_model_output_path being
    # set by a TRAIN pass in the same process.
    config_backbone["load"] = four_d_ct_cost_unrolling.get_checkpoints_path(backbone_model_output_path)
    print(f"Infer backbone with checkpoints: {config_backbone['load']}")
    four_d_ct_cost_unrolling.infer_backbone(
        template_image_path=unlabeled_synthetic_img_path,
        unlabeled_image_path=template_synthetic_img_path,
        template_LV_seg_path=unlabeled_synthetic_mask_path,
        unlabeled_LV_seg_path=template_synthetic_mask_path,
        template_shell_seg_path=unlabeled_synthetic_extra_mask_path,
        unlabeled_shell_seg_path=template_synthetic_extra_mask_path,
        flows_gt_path=synthetic_flow_path,
        args=EasyDict(config_backbone)
    )
print("Training with anatomical loss (segmentation based)")
# Stage 2: same backbone plus a segmentation-based anatomical loss.
# Mirrors the backbone config block above, reading from the w_segmentation
# section of the run config.
config_w_segmentation = four_d_ct_cost_unrolling.get_default_w_segmentation_config()
config_w_segmentation["save_iter"] = 2
config_w_segmentation["inference_args"]["inference_flow_median_filter_size"] = False
config_w_segmentation["epochs"] = config.fourD_ct_cost_unrolling.w_segmentation.early_stopping.epochs
config_w_segmentation["valid_type"] = "synthetic+basic"
config_w_segmentation["w_sm_scales"] = config.fourD_ct_cost_unrolling.w_segmentation["w_sm_scales"]
config_w_segmentation["output_root"] = os.path.join(outputs_path, config_w_segmentation["output_root"])
config_w_segmentation["visualization_arrow_scale_factor"] = 1
config_w_segmentation["cuda_device"] = config.general.cuda_device
config_w_segmentation["scale_down_by"] = config.fourD_ct_cost_unrolling.w_segmentation.scale_down_by
config_w_segmentation["metric_for_early_stopping"] = config.fourD_ct_cost_unrolling.w_segmentation.early_stopping.metric_for_early_stopping
config_w_segmentation["max_metric_not_dropping_patience"] = config.fourD_ct_cost_unrolling.w_segmentation.early_stopping.max_metric_not_dropping_patience
if "TRAIN" in MODE:
    # Overfit the segmentation-loss model on the same synthetic pair
    # (template/unlabeled arguments cross-wired as in the backbone TRAIN call).
    w_segmentation_model_output_path = four_d_ct_cost_unrolling.overfit_w_seg(
        template_image_path=unlabeled_synthetic_img_path,
        unlabeled_image_path=template_synthetic_img_path,
        template_LV_seg_path=unlabeled_synthetic_mask_path,
        unlabeled_LV_seg_path=template_synthetic_mask_path,
        template_shell_seg_path=unlabeled_synthetic_extra_mask_path,
        unlabeled_shell_seg_path=template_synthetic_extra_mask_path,
        flows_gt_path=synthetic_flow_path,
        error_radial_coordinates_path=error_radial_coordinates_path,
        error_circumferential_coordinates_path=error_circumferential_coordinates_path,
        error_longitudinal_coordinates_path=error_longitudinal_coordinates_path,
        voxelized_normals_path=voxelized_normals_path,
        args=EasyDict(config_w_segmentation)
    )
    print("Segmentation-based model output path:", w_segmentation_model_output_path)
if "EVAL" in MODE:
    # Validate the segmentation-loss model against the synthetic GT flow.
    # NOTE(review): w_segmentation_model_output_path requires a TRAIN pass in
    # the same process; MODE="EVAL" alone raises NameError here.
    config_w_segmentation["valid_type"] = "synthetic"
    config_w_segmentation["load"] = four_d_ct_cost_unrolling.get_checkpoints_path(w_segmentation_model_output_path)
    print(f"Evaluating segmentation-based with checkpoints: {config_w_segmentation['load']}")
    errors_w_seg = four_d_ct_cost_unrolling.validate_w_seg(
        template_image_path=unlabeled_synthetic_img_path,
        unlabeled_image_path=template_synthetic_img_path,
        template_LV_seg_path=unlabeled_synthetic_mask_path,
        unlabeled_LV_seg_path=template_synthetic_mask_path,
        template_shell_seg_path=unlabeled_synthetic_extra_mask_path,
        unlabeled_shell_seg_path=template_synthetic_extra_mask_path,
        flows_gt_path=synthetic_flow_path,
        error_radial_coordinates_path=error_radial_coordinates_path,
        error_circumferential_coordinates_path=error_circumferential_coordinates_path,
        error_longitudinal_coordinates_path=error_longitudinal_coordinates_path,
        voxelized_normals_path=voxelized_normals_path,
        args=EasyDict(config_w_segmentation)
    )
print("Training with constraints loss")
# Stage 3: model with the correspondence-derived constraints loss,
# warm-started from the backbone checkpoint (see the TRAIN branch below).
config_constraints = four_d_ct_cost_unrolling.get_default_w_constraints_config()
config_constraints["save_iter"] = 2
config_constraints["inference_args"]["inference_flow_median_filter_size"] = False
config_constraints["epochs"] = config.fourD_ct_cost_unrolling.w_constraints.early_stopping.epochs
config_constraints["valid_type"] = "synthetic+basic"
config_constraints["w_sm_scales"] = config.fourD_ct_cost_unrolling.w_constraints["w_sm_scales"]
config_constraints["output_root"] = os.path.join(outputs_path, config_constraints["output_root"])
config_constraints["visualization_arrow_scale_factor"] = 1
# NOTE(review): per-scale constraint-loss weights are hard-coded here rather
# than read from the run config like the other hyperparameters.
config_constraints["w_constraints_scales"] = [100.0, 100.0, 100.0, 100.0, 100.0]
config_constraints["cuda_device"] = config.general.cuda_device
config_constraints["scale_down_by"] = config.fourD_ct_cost_unrolling.w_constraints.scale_down_by
config_constraints["metric_for_early_stopping"] = config.fourD_ct_cost_unrolling.w_constraints.early_stopping.metric_for_early_stopping
config_constraints["max_metric_not_dropping_patience"] = config.fourD_ct_cost_unrolling.w_constraints.early_stopping.max_metric_not_dropping_patience
if "TRAIN" in MODE:
    # Warm-start from the stage-1 backbone checkpoint, then train with the
    # dense 2D-derived constraints added to the loss.
    config_constraints["load"] = four_d_ct_cost_unrolling.get_checkpoints_path(backbone_model_output_path)
    constraints_model_output_path = four_d_ct_cost_unrolling.overfit_w_constraints(
        template_image_path=unlabeled_synthetic_img_path,
        unlabeled_image_path=template_synthetic_img_path,
        template_LV_seg_path=unlabeled_synthetic_mask_path,
        unlabeled_LV_seg_path=template_synthetic_mask_path,
        template_shell_seg_path=unlabeled_synthetic_extra_mask_path,
        unlabeled_shell_seg_path=template_synthetic_extra_mask_path,
        two_d_constraints_path=two_d_constraints_path,
        flows_gt_path=synthetic_flow_path,
        error_radial_coordinates_path=error_radial_coordinates_path,
        error_circumferential_coordinates_path=error_circumferential_coordinates_path,
        error_longitudinal_coordinates_path=error_longitudinal_coordinates_path,
        voxelized_normals_path=voxelized_normals_path,
        args=EasyDict(config_constraints)
    )
    print(f"Constraints model output path: {constraints_model_output_path}")
if "EVAL" in MODE:
    # Validate the constraints model against the synthetic GT flow.
    # NOTE(review): constraints_model_output_path requires a TRAIN pass in the
    # same process; MODE="EVAL" alone raises NameError here.
    config_constraints["valid_type"] = "synthetic"
    config_constraints["load"] = four_d_ct_cost_unrolling.get_checkpoints_path(constraints_model_output_path)
    print(f"Evaluating constraints with checkpoints: {config_constraints['load']}")
    errors_w_constraints = four_d_ct_cost_unrolling.validate_w_constraints(
        template_image_path=unlabeled_synthetic_img_path,
        unlabeled_image_path=template_synthetic_img_path,
        template_LV_seg_path=unlabeled_synthetic_mask_path,
        unlabeled_LV_seg_path=template_synthetic_mask_path,
        template_shell_seg_path=unlabeled_synthetic_extra_mask_path,
        unlabeled_shell_seg_path=template_synthetic_extra_mask_path,
        two_d_constraints_path=two_d_constraints_path,
        flows_gt_path=synthetic_flow_path,
        error_radial_coordinates_path=error_radial_coordinates_path,
        error_circumferential_coordinates_path=error_circumferential_coordinates_path,
        error_longitudinal_coordinates_path=error_longitudinal_coordinates_path,
        voxelized_normals_path=voxelized_normals_path,
        args=EasyDict(config_constraints)
    )
if "INFER" in MODE:
    # Run inference with the trained constraints model; save_mask=True also
    # writes the warped mask alongside the predicted flow.
    config_constraints["load"] = four_d_ct_cost_unrolling.get_checkpoints_path(constraints_model_output_path)
    print(f"Infer constraints with checkpoints: {config_constraints['load']}")
    infer_constraints_model_output_path = four_d_ct_cost_unrolling.infer_w_constraints(
        template_image_path=unlabeled_synthetic_img_path,
        unlabeled_image_path=template_synthetic_img_path,
        template_LV_seg_path=unlabeled_synthetic_mask_path,
        unlabeled_LV_seg_path=template_synthetic_mask_path,
        template_shell_seg_path=unlabeled_synthetic_extra_mask_path,
        unlabeled_shell_seg_path=template_synthetic_extra_mask_path,
        two_d_constraints_path=two_d_constraints_path,
        flows_gt_path=synthetic_flow_path,
        save_mask=True,
        args=EasyDict(config_constraints)
    )
if "INFER" in MODE:
    # Collect inference products plus the synthetic GT into one folder for
    # external drawing/visualization tools.
    for_drawing_path = os.path.join(outputs_path, "for_drawing")
    os.makedirs(for_drawing_path, exist_ok=True)
    # Re-format/re-scale the predicted flow (see utils.handle_pred_flow) into
    # the drawing folder.
    handle_pred_flow(config.fourD_ct_cost_unrolling.w_constraints.scale_down_by, infer_constraints_model_output_path, for_drawing_path) # shape 3xyz
    shutil.copyfile(synthetic_flow_path, os.path.join(for_drawing_path,"ground_truth_flow.npy"))
    shutil.copyfile(template_synthetic_img_path, os.path.join(for_drawing_path,"template_img.npy"))
    shutil.copyfile(unlabeled_synthetic_img_path, os.path.join(for_drawing_path,"unlabeled_img.npy"))
    shutil.copyfile(template_synthetic_mask_path, os.path.join(for_drawing_path,"template_mask.npy"))
    shutil.copyfile(unlabeled_synthetic_mask_path, os.path.join(for_drawing_path,"unlabeled_mask.npy"))
if "EVAL" in MODE:
    # Aggregate the per-model validation metrics; non-float entries (arrays,
    # paths, etc.) are stripped so the result is JSON-serializable.
    errors = {
        "errors_backbone": remove_non_floats_from_dict(errors_backbone),
        "errors_w_seg": remove_non_floats_from_dict(errors_w_seg),
        "errors_w_constraints": remove_non_floats_from_dict(errors_w_constraints)
    }
    # Hoist the path once and stream with json.dump instead of
    # f.write(json.dumps(...)).
    errors_json_path = os.path.join(outputs_path, "errors.json")
    print(f"Writing all errors to {errors_json_path}")
    with open(errors_json_path, "w") as f:
        json.dump(errors, f, indent=4)