diff --git a/run_dipy_gpu.py b/run_dipy_gpu.py index ca51cba..b417bd9 100644 --- a/run_dipy_gpu.py +++ b/run_dipy_gpu.py @@ -85,10 +85,10 @@ help="default: 0.5") args = parser.parse_args() -img = load_nifti(args.nifti_file, return_img=True) +img_data, img_affine, img = load_nifti(args.nifti_file, return_img=True) voxel_order = "".join(aff2axcodes(img.affine)) gtab = gradient_table(args.bvals, args.bvecs) -mask = load_nifti(args.mask_nifti, return_img=True) +mask_data, mask_affine, mask = load_nifti(args.mask_nifti, return_img=True) data = img.get_fdata() # resample mask if necessary @@ -103,8 +103,14 @@ mask = mask.get_fdata() # load or compute and save FA file -if (args.fa_numpy is not None) and os.path.isfile(args.fa_numpy): - FA = np.load(args.fa_numpy, allow_pickle=True) +if (args.fa_file is not None) and os.path.isfile(args.fa_file): + _, fa_extension = os.path.splitext(filename) + if fa_extension in ['.npy', '.npz', '.pkl']: + FA = np.load(args.fa_file, allow_pickle=True) + elif fa_extension in ['.nii','.gz']: + FA, FA_affine = load_nifti(args.fa_file) + else: + raise TypeError('FA filename is not one of the supported format (.npy, .npz, .pkl, .nii, .gz).') else: # Fit tenmodel = dti.TensorModel(gtab, fit_method='WLS') @@ -114,15 +120,15 @@ FA = tenfit.fa FA[np.isnan(FA)] = 0 - if args.fa_numpy is not None: - np.save(args.fa_numpy, FA) + if args.fa_file is not None: + np.save(args.fa_file, FA) # Setup tissue_classifier args metric_map = np.asarray(FA, 'float64') # resample roi if necessary if args.roi_nifti is not None: - roi_data, roi = load_nifti(args.roi_nifti, + roi_data, roi_affine, roi = load_nifti(args.roi_nifti, return_img=True, as_ndarray=True) else: