2222from collections .abc import Callable , Iterable , Iterator , Sequence
2323from dataclasses import dataclass
2424from pathlib import Path
25- from typing import TYPE_CHECKING , Any
25+ from typing import TYPE_CHECKING , Any , Union
2626
2727import numpy as np
2828from torch .utils .data ._utils .collate import np_str_obj_array_pattern
3838from monai .utils import MetaKeys , SpaceKeys , TraceKeys , ensure_tuple , optional_import , require_pkg
3939
4040if TYPE_CHECKING :
41+ import cupy as cp
4142 import itk
4243 import nibabel as nib
4344 import nrrd
4445 import pydicom
4546 from nibabel .nifti1 import Nifti1Image
4647 from PIL import Image as PILImage
4748
48- has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True
49+ has_nrrd = has_itk = has_nib = has_pil = has_pydicom , has_cp = True
50+ Ndarray = Union [np .ndarray , cp .ndarray ]
51+
4952else :
5053 itk , has_itk = optional_import ("itk" , allow_namespace_pkg = True )
5154 nib , has_nib = optional_import ("nibabel" )
5457 pydicom , has_pydicom = optional_import ("pydicom" )
5558 nrrd , has_nrrd = optional_import ("nrrd" , allow_namespace_pkg = True )
5659
57- cp , has_cp = optional_import ("cupy" )
60+ cp , has_cp = optional_import ("cupy" )
61+ if has_cp :
62+ Ndarray = Union [np .ndarray , cp .ndarray ]
63+ else :
64+ Ndarray = np .ndarray
65+
5866kvikio , has_kvikio = optional_import ("kvikio" )
5967
6068__all__ = ["ImageReader" , "ITKReader" , "NibabelReader" , "NumpyReader" , "PILReader" , "PydicomReader" , "NrrdReader" ]
@@ -107,7 +115,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] |
107115 raise NotImplementedError (f"Subclass { self .__class__ .__name__ } must implement this method." )
108116
109117 @abstractmethod
110- def get_data (self , img ) -> tuple [np . ndarray , dict ]:
118+ def get_data (self , img ) -> tuple [Ndarray , dict ]:
111119 """
112120 Extract data array and metadata from loaded image and return them.
113121 This function must return two objects, the first is a numpy array of image data,
@@ -143,7 +151,7 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
143151 )
144152
145153
146- def _stack_images (image_list : list , meta_dict : dict , to_cupy : bool = False ):
154+ def _stack_images (image_list : list [ Ndarray ] , meta_dict : dict , to_cupy : bool = False ) -> Ndarray :
147155 if len (image_list ) <= 1 :
148156 return image_list [0 ]
149157 if not is_no_channel (meta_dict .get (MetaKeys .ORIGINAL_CHANNEL_DIM , None )):
@@ -269,7 +277,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
269277 img_ .append (itk .imread (name , ** kwargs_ ))
270278 return img_ if len (filenames ) > 1 else img_ [0 ]
271279
272- def get_data (self , img ) -> tuple [np . ndarray , dict ]:
280+ def get_data (self , img ) -> tuple [Ndarray , dict ]:
273281 """
274282 Extract data array and metadata from loaded image and return them.
275283 This function returns two objects, first is numpy array of image data, second is dict of metadata.
@@ -281,7 +289,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
281289 img: an ITK image object loaded from an image file or a list of ITK image objects.
282290
283291 """
284- img_array : list [np . ndarray ] = []
292+ img_array : list [Ndarray ] = []
285293 compatible_meta : dict = {}
286294
287295 for i in ensure_tuple (img ):
@@ -616,7 +624,7 @@ def _combine_dicom_series(self, data: Iterable, filenames: Sequence[PathLike]):
616624
617625 return stack_array , stack_metadata
618626
619- def get_data (self , data ) -> tuple [np . ndarray , dict ]:
627+ def get_data (self , data ) -> tuple [Ndarray , dict ]:
620628 """
621629 Extract data array and metadata from loaded image and return them.
622630 This function returns two objects, first is numpy array of image data, second is dict of metadata.
@@ -663,10 +671,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
663671 metadata [MetaKeys .SPATIAL_SHAPE ] = data_array .shape
664672 dicom_data .append ((data_array , metadata ))
665673
666- # TODO: the actual type is list[np.ndarray | cp.ndarray]
667- # should figure out how to define correct types without having cupy not found error
668- # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
669- img_array : list [np .ndarray ] = []
674+ img_array : list [Ndarray ] = []
670675 compatible_meta : dict = {}
671676
672677 for data_array , metadata in ensure_tuple (dicom_data ):
@@ -841,7 +846,7 @@ def _get_seg_data(self, img, filename):
841846 if self .label_dict is not None :
842847 metadata ["labels" ] = self .label_dict
843848 if self .to_gpu :
844- all_segs = cp .zeros ([* spatial_shape , len (self .label_dict )], dtype = array_data .dtype )
849+ all_segs : Ndarray = cp .zeros ([* spatial_shape , len (self .label_dict )], dtype = array_data .dtype )
845850 else :
846851 all_segs = np .zeros ([* spatial_shape , len (self .label_dict )], dtype = array_data .dtype )
847852 else :
@@ -899,7 +904,7 @@ def _get_seg_data(self, img, filename):
899904
900905 return all_segs , metadata
901906
902- def _get_array_data_from_gpu (self , img , filename ):
907+ def _get_array_data_from_gpu (self , img , filename ) -> Ndarray :
903908 """
904909 Get the raw array data of the image. This function is used when `to_gpu` is set to True.
905910
@@ -954,7 +959,7 @@ def _get_array_data_from_gpu(self, img, filename):
954959
955960 return data
956961
957- def _get_array_data (self , img , filename ):
962+ def _get_array_data (self , img , filename ) -> Ndarray :
958963 """
959964 Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data
960965 will be rescaled. The output data has the dtype float32 if the rescaling is applied.
@@ -1092,7 +1097,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
10921097 img_ .append (img ) # type: ignore
10931098 return img_ if len (filenames ) > 1 else img_ [0 ]
10941099
1095- def get_data (self , img ) -> tuple [np . ndarray , dict ]:
1100+ def get_data (self , img ) -> tuple [Ndarray , dict ]:
10961101 """
10971102 Extract data array and metadata from loaded image and return them.
10981103 This function returns two objects, first is numpy array of image data, second is dict of metadata.
@@ -1104,10 +1109,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
11041109 img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.
11051110
11061111 """
1107- # TODO: the actual type is list[np.ndarray | cp.ndarray]
1108- # should figure out how to define correct types without having cupy not found error
1109- # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
1110- img_array : list [np .ndarray ] = []
1112+ img_array : list [Ndarray ] = []
11111113 compatible_meta : dict = {}
11121114
11131115 for i , filename in zip (ensure_tuple (img ), self .filenames ):
@@ -1186,7 +1188,7 @@ def _get_spatial_shape(self, img):
11861188 spatial_rank = max (min (ndim , 3 ), 1 )
11871189 return np .asarray (size [:spatial_rank ])
11881190
1189- def _get_array_data (self , img , filename ):
1191+ def _get_array_data (self , img , filename ) -> Ndarray :
11901192 """
11911193 Get the raw array data of the image, converted to Numpy array.
11921194
@@ -1281,7 +1283,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
12811283
12821284 return img_ if len (img_ ) > 1 else img_ [0 ]
12831285
1284- def get_data (self , img ) -> tuple [np . ndarray , dict ]:
1286+ def get_data (self , img ) -> tuple [Ndarray , dict ]:
12851287 """
12861288 Extract data array and metadata from loaded image and return them.
12871289 This function returns two objects, first is numpy array of image data, second is dict of metadata.
@@ -1293,7 +1295,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
12931295 img: a Numpy array loaded from a file or a list of Numpy arrays.
12941296
12951297 """
1296- img_array : list [np . ndarray ] = []
1298+ img_array : list [Ndarray ] = []
12971299 compatible_meta : dict = {}
12981300 if isinstance (img , np .ndarray ):
12991301 img = (img ,)
@@ -1374,7 +1376,7 @@ def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs):
13741376
13751377 return img_ if len (filenames ) > 1 else img_ [0 ]
13761378
1377- def get_data (self , img ) -> tuple [np . ndarray , dict ]:
1379+ def get_data (self , img ) -> tuple [Ndarray , dict ]:
13781380 """
13791381 Extract data array and metadata from loaded image and return them.
13801382 This function returns two objects, first is numpy array of image data, second is dict of metadata.
@@ -1388,7 +1390,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
13881390 img: a PIL Image object loaded from a file or a list of PIL Image objects.
13891391
13901392 """
1391- img_array : list [np . ndarray ] = []
1393+ img_array : list [Ndarray ] = []
13921394 compatible_meta : dict = {}
13931395
13941396 for i in ensure_tuple (img ):
@@ -1425,7 +1427,7 @@ def _get_spatial_shape(self, img):
14251427class NrrdImage :
14261428 """Class to wrap nrrd image array and metadata header"""
14271429
1428- array : np . ndarray
1430+ array : Ndarray
14291431 header : dict
14301432
14311433
@@ -1495,7 +1497,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] |
14951497 img_ .append (nrrd_image )
14961498 return img_ if len (filenames ) > 1 else img_ [0 ]
14971499
1498- def get_data (self , img : NrrdImage | list [NrrdImage ]) -> tuple [np . ndarray , dict ]:
1500+ def get_data (self , img : NrrdImage | list [NrrdImage ]) -> tuple [Ndarray , dict ]:
14991501 """
15001502 Extract data array and metadata from loaded image and return them.
15011503 This function must return two objects, the first is a numpy array of image data,
@@ -1505,7 +1507,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]:
15051507 img: a `NrrdImage` loaded from an image file or a list of image objects.
15061508
15071509 """
1508- img_array : list [np . ndarray ] = []
1510+ img_array : list [Ndarray ] = []
15091511 compatible_meta : dict = {}
15101512
15111513 for i in ensure_tuple (img ):
0 commit comments