Skip to content

Commit 57a381b

Browse files
committed
Improve CuPy type hinting in ImageReader
Signed-off-by: ytl0623 <david89062388@gmail.com>
1 parent 57fdd59 commit 57a381b

File tree

1 file changed

+30
-28
lines changed

1 file changed

+30
-28
lines changed

monai/data/image_reader.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from collections.abc import Callable, Iterable, Iterator, Sequence
2323
from dataclasses import dataclass
2424
from pathlib import Path
25-
from typing import TYPE_CHECKING, Any
25+
from typing import TYPE_CHECKING, Any, Union
2626

2727
import numpy as np
2828
from torch.utils.data._utils.collate import np_str_obj_array_pattern
@@ -38,14 +38,17 @@
3838
from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg
3939

4040
if TYPE_CHECKING:
41+
import cupy as cp
4142
import itk
4243
import nibabel as nib
4344
import nrrd
4445
import pydicom
4546
from nibabel.nifti1 import Nifti1Image
4647
from PIL import Image as PILImage
4748

48-
has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True
49+
has_nrrd = has_itk = has_nib = has_pil = has_pydicom, has_cp = True
50+
Ndarray = Union[np.ndarray, cp.ndarray]
51+
4952
else:
5053
itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
5154
nib, has_nib = optional_import("nibabel")
@@ -54,7 +57,12 @@
5457
pydicom, has_pydicom = optional_import("pydicom")
5558
nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)
5659

57-
cp, has_cp = optional_import("cupy")
60+
cp, has_cp = optional_import("cupy")
61+
if has_cp:
62+
Ndarray = Union[np.ndarray, cp.ndarray]
63+
else:
64+
Ndarray = np.ndarray
65+
5866
kvikio, has_kvikio = optional_import("kvikio")
5967

6068
__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]
@@ -107,7 +115,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] |
107115
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
108116

109117
@abstractmethod
110-
def get_data(self, img) -> tuple[np.ndarray, dict]:
118+
def get_data(self, img) -> tuple[Ndarray, dict]:
111119
"""
112120
Extract data array and metadata from loaded image and return them.
113121
This function must return two objects, the first is a numpy array of image data,
@@ -143,7 +151,7 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
143151
)
144152

145153

146-
def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False):
154+
def _stack_images(image_list: list[Ndarray], meta_dict: dict, to_cupy: bool = False) -> Ndarray:
147155
if len(image_list) <= 1:
148156
return image_list[0]
149157
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
@@ -269,7 +277,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
269277
img_.append(itk.imread(name, **kwargs_))
270278
return img_ if len(filenames) > 1 else img_[0]
271279

272-
def get_data(self, img) -> tuple[np.ndarray, dict]:
280+
def get_data(self, img) -> tuple[Ndarray, dict]:
273281
"""
274282
Extract data array and metadata from loaded image and return them.
275283
This function returns two objects, first is numpy array of image data, second is dict of metadata.
@@ -281,7 +289,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
281289
img: an ITK image object loaded from an image file or a list of ITK image objects.
282290
283291
"""
284-
img_array: list[np.ndarray] = []
292+
img_array: list[Ndarray] = []
285293
compatible_meta: dict = {}
286294

287295
for i in ensure_tuple(img):
@@ -616,7 +624,7 @@ def _combine_dicom_series(self, data: Iterable, filenames: Sequence[PathLike]):
616624

617625
return stack_array, stack_metadata
618626

619-
def get_data(self, data) -> tuple[np.ndarray, dict]:
627+
def get_data(self, data) -> tuple[Ndarray, dict]:
620628
"""
621629
Extract data array and metadata from loaded image and return them.
622630
This function returns two objects, first is numpy array of image data, second is dict of metadata.
@@ -663,10 +671,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
663671
metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape
664672
dicom_data.append((data_array, metadata))
665673

666-
# TODO: the actual type is list[np.ndarray | cp.ndarray]
667-
# should figure out how to define correct types without having cupy not found error
668-
# https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
669-
img_array: list[np.ndarray] = []
674+
img_array: list[Ndarray] = []
670675
compatible_meta: dict = {}
671676

672677
for data_array, metadata in ensure_tuple(dicom_data):
@@ -841,7 +846,7 @@ def _get_seg_data(self, img, filename):
841846
if self.label_dict is not None:
842847
metadata["labels"] = self.label_dict
843848
if self.to_gpu:
844-
all_segs = cp.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)
849+
all_segs: Ndarray = cp.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)
845850
else:
846851
all_segs = np.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)
847852
else:
@@ -899,7 +904,7 @@ def _get_seg_data(self, img, filename):
899904

900905
return all_segs, metadata
901906

902-
def _get_array_data_from_gpu(self, img, filename):
907+
def _get_array_data_from_gpu(self, img, filename) -> Ndarray:
903908
"""
904909
Get the raw array data of the image. This function is used when `to_gpu` is set to True.
905910
@@ -954,7 +959,7 @@ def _get_array_data_from_gpu(self, img, filename):
954959

955960
return data
956961

957-
def _get_array_data(self, img, filename):
962+
def _get_array_data(self, img, filename) -> Ndarray:
958963
"""
959964
Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data
960965
will be rescaled. The output data has the dtype float32 if the rescaling is applied.
@@ -1092,7 +1097,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
10921097
img_.append(img) # type: ignore
10931098
return img_ if len(filenames) > 1 else img_[0]
10941099

1095-
def get_data(self, img) -> tuple[np.ndarray, dict]:
1100+
def get_data(self, img) -> tuple[Ndarray, dict]:
10961101
"""
10971102
Extract data array and metadata from loaded image and return them.
10981103
This function returns two objects, first is numpy array of image data, second is dict of metadata.
@@ -1104,10 +1109,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
11041109
img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.
11051110
11061111
"""
1107-
# TODO: the actual type is list[np.ndarray | cp.ndarray]
1108-
# should figure out how to define correct types without having cupy not found error
1109-
# https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
1110-
img_array: list[np.ndarray] = []
1112+
img_array: list[Ndarray] = []
11111113
compatible_meta: dict = {}
11121114

11131115
for i, filename in zip(ensure_tuple(img), self.filenames):
@@ -1186,7 +1188,7 @@ def _get_spatial_shape(self, img):
11861188
spatial_rank = max(min(ndim, 3), 1)
11871189
return np.asarray(size[:spatial_rank])
11881190

1189-
def _get_array_data(self, img, filename):
1191+
def _get_array_data(self, img, filename) -> Ndarray:
11901192
"""
11911193
Get the raw array data of the image, converted to Numpy array.
11921194
@@ -1281,7 +1283,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
12811283

12821284
return img_ if len(img_) > 1 else img_[0]
12831285

1284-
def get_data(self, img) -> tuple[np.ndarray, dict]:
1286+
def get_data(self, img) -> tuple[Ndarray, dict]:
12851287
"""
12861288
Extract data array and metadata from loaded image and return them.
12871289
This function returns two objects, first is numpy array of image data, second is dict of metadata.
@@ -1293,7 +1295,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
12931295
img: a Numpy array loaded from a file or a list of Numpy arrays.
12941296
12951297
"""
1296-
img_array: list[np.ndarray] = []
1298+
img_array: list[Ndarray] = []
12971299
compatible_meta: dict = {}
12981300
if isinstance(img, np.ndarray):
12991301
img = (img,)
@@ -1374,7 +1376,7 @@ def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs):
13741376

13751377
return img_ if len(filenames) > 1 else img_[0]
13761378

1377-
def get_data(self, img) -> tuple[np.ndarray, dict]:
1379+
def get_data(self, img) -> tuple[Ndarray, dict]:
13781380
"""
13791381
Extract data array and metadata from loaded image and return them.
13801382
This function returns two objects, first is numpy array of image data, second is dict of metadata.
@@ -1388,7 +1390,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
13881390
img: a PIL Image object loaded from a file or a list of PIL Image objects.
13891391
13901392
"""
1391-
img_array: list[np.ndarray] = []
1393+
img_array: list[Ndarray] = []
13921394
compatible_meta: dict = {}
13931395

13941396
for i in ensure_tuple(img):
@@ -1425,7 +1427,7 @@ def _get_spatial_shape(self, img):
14251427
class NrrdImage:
14261428
"""Class to wrap nrrd image array and metadata header"""
14271429

1428-
array: np.ndarray
1430+
array: Ndarray
14291431
header: dict
14301432

14311433

@@ -1495,7 +1497,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] |
14951497
img_.append(nrrd_image)
14961498
return img_ if len(filenames) > 1 else img_[0]
14971499

1498-
def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]:
1500+
def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[Ndarray, dict]:
14991501
"""
15001502
Extract data array and metadata from loaded image and return them.
15011503
This function must return two objects, the first is a numpy array of image data,
@@ -1505,7 +1507,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]:
15051507
img: a `NrrdImage` loaded from an image file or a list of image objects.
15061508
15071509
"""
1508-
img_array: list[np.ndarray] = []
1510+
img_array: list[Ndarray] = []
15091511
compatible_meta: dict = {}
15101512

15111513
for i in ensure_tuple(img):

0 commit comments

Comments
 (0)