diff --git a/Wrapping/Generators/Python/Tests/extras.py b/Wrapping/Generators/Python/Tests/extras.py index c0601e5ef2d..584a3cf7abe 100644 --- a/Wrapping/Generators/Python/Tests/extras.py +++ b/Wrapping/Generators/Python/Tests/extras.py @@ -655,3 +655,76 @@ def assert_images_equal(base, test): except ImportError: print("vtk not imported. Skipping vtk conversion tests") pass + +# Test SimpleITK to ITK conversion +try: + import SimpleITK as sitk + + print("Testing SimpleITK conversion") + + # Test 2D scalar image + sitk_image = sitk.Image([12, 8], sitk.sitkFloat32) + sitk_image.SetSpacing([0.5, 1.5]) + sitk_image.SetOrigin([10.0, 20.0]) + theta = np.radians(30) + cosine = np.cos(theta) + sine = np.sin(theta) + sitk_image.SetDirection([cosine, -sine, sine, cosine]) + # Fill with known data + arr_2d = np.random.rand(8, 12).astype(np.float32) + for y in range(8): + for x in range(12): + sitk_image[x, y] = float(arr_2d[y, x]) + + itk_image = itk.image_from_simpleitk(sitk_image) + assert np.array_equal( + np.array(sitk_image.GetSpacing()), np.array(itk.spacing(itk_image)) + ) + assert np.array_equal( + np.array(sitk_image.GetOrigin()), np.array(itk.origin(itk_image)) + ) + sitk_dir = np.array(sitk_image.GetDirection()).reshape(2, 2) + itk_dir = itk.array_from_matrix(itk_image.GetDirection()) + assert np.allclose(sitk_dir, itk_dir) + assert np.array_equal( + sitk.GetArrayFromImage(sitk_image), itk.array_from_image(itk_image) + ) + + # Test 3D scalar image + sitk_3d = sitk.Image([4, 6, 8], sitk.sitkFloat32) + sitk_3d.SetSpacing([1.0, 2.0, 3.0]) + sitk_3d.SetOrigin([5.0, 10.0, 15.0]) + itk_3d = itk.image_from_simpleitk(sitk_3d) + assert np.array_equal(np.array(sitk_3d.GetSpacing()), np.array(itk.spacing(itk_3d))) + assert np.array_equal(np.array(sitk_3d.GetOrigin()), np.array(itk.origin(itk_3d))) + assert itk_3d.GetImageDimension() == 3 + + # Test vector image (multi-component -> VectorImage) + sitk_vector = sitk.Image([10, 10], sitk.sitkVectorFloat32, 3) + sitk_vector.SetSpacing([0.8, 1.2]) + itk_vector = itk.image_from_simpleitk(sitk_vector) + assert itk_vector.GetNumberOfComponentsPerPixel() == 3 + assert isinstance(itk_vector, itk.VectorImage[itk.F, 2]) + assert np.array_equal( + np.array(sitk_vector.GetSpacing()), np.array(itk.spacing(itk_vector)) + ) + + # Test MetaDataDictionary preservation + sitk_meta = sitk.Image([4, 4], sitk.sitkFloat32) + sitk_meta.SetMetaData("test_key", "test_value") + sitk_meta.SetMetaData("0010|0010", "patient_name") + itk_meta = itk.image_from_simpleitk(sitk_meta) + meta_dict = itk_meta.GetMetaDataDictionary() + assert meta_dict["test_key"] == "test_value" + assert meta_dict["0010|0010"] == "patient_name" + + # Test auto-detection in filter function + sitk_input = sitk.Image([32, 32], sitk.sitkFloat32) + result = itk.median_image_filter(sitk_input, radius=1) + assert isinstance(result, itk.Image[itk.F, 2]) + + print("SimpleITK conversion tests passed") + +except ImportError: + print("SimpleITK not imported. Skipping SimpleITK conversion tests") + pass diff --git a/Wrapping/Generators/Python/itk/support/extras.py b/Wrapping/Generators/Python/itk/support/extras.py index a01f9a3f3b9..a6b1f4faf39 100644 --- a/Wrapping/Generators/Python/itk/support/extras.py +++ b/Wrapping/Generators/Python/itk/support/extras.py @@ -95,6 +95,7 @@ "image_from_xarray", "vtk_image_from_image", "image_from_vtk_image", + "image_from_simpleitk", "dict_from_image", "image_from_dict", "image_intensity_min_max", @@ -790,6 +791,88 @@ def image_from_vtk_image(vtk_image: vtk.vtkImageData) -> itkt.ImageBase: return l_image +def image_from_simpleitk(sitk_image) -> "itkt.ImageBase": + """Convert a SimpleITK Image to an itk.Image. + + The source image is accessed through generic interfaces: the NumPy + ``__array__`` protocol for pixel data, ``keys()`` for available + metadata keys, and dictionary-style ``[]`` access for metadata + values. The recognised spatial keys are ``'spacing'``, + ``'origin'``, and ``'direction'``; all other keys returned by + ``keys()`` are copied into the ITK MetaDataDictionary. + + This makes the function forward-compatible with other image + libraries that expose the same conventions (e.g. SimpleITK). + + Pixel data is copied once (SimpleITK buffers are read-only). + Multi-component images are converted to itk.VectorImage. + + Parameters + ---------- + sitk_image : + A SimpleITK Image object (or any object that supports + ``__array__``, ``keys()``, and dictionary ``[]`` access). + + Returns + ------- + image : + The resulting itk.Image (or itk.VectorImage for multi-component pixels). + """ + import itk + + array = np.array(sitk_image) + dim = array.ndim + + spatial_keys = {"spacing", "origin", "direction"} + + if hasattr(sitk_image, "keys"): + keys = list(sitk_image.keys()) + else: + # Probe spatial keys directly (e.g. SimpleITK 3.x supports [] but not keys()) + keys = [] + for k in spatial_keys: + try: + sitk_image[k] + keys.append(k) + except (KeyError, TypeError, AttributeError): + pass + # Collect MetaData keys via dedicated accessor if available + if hasattr(sitk_image, "GetMetaDataKeys"): + keys.extend(sitk_image.GetMetaDataKeys()) + + if "spacing" in keys: + spacing = sitk_image["spacing"] + dim = len(spacing) + + is_vector = array.ndim > dim + + if is_vector: + PixelType = _get_itk_pixelid(array) + ImageType = itk.VectorImage[PixelType, dim] + l_image = itk.image_view_from_array(array, ttype=ImageType) + else: + l_image = itk.image_view_from_array(array) + + if "spacing" in keys: + l_image.SetSpacing(spacing) + + if "origin" in keys: + l_image.SetOrigin(sitk_image["origin"]) + + if "direction" in keys: + direction = np.array(sitk_image["direction"]).reshape(dim, dim) + l_image.SetDirection(direction) + + for key in keys: + if key not in spatial_keys: + l_image.GetMetaDataDictionary()[key] = sitk_image[key] + + # Keep a reference to the numpy array to prevent garbage collection + l_image._SetBase(array) + + return l_image + + def dict_from_image(image: itkt.Image) -> dict: """Serialize a Python itk.Image object to a pickable Python dictionary.""" import itk diff --git a/Wrapping/Generators/Python/itk/support/helpers.py b/Wrapping/Generators/Python/itk/support/helpers.py index 3d271d141ee..8cda02fdbe6 100644 --- a/Wrapping/Generators/Python/itk/support/helpers.py +++ b/Wrapping/Generators/Python/itk/support/helpers.py @@ -38,6 +38,13 @@ _HAVE_TORCH = True except importlib.metadata.PackageNotFoundError: pass +_HAVE_SIMPLEITK = False +try: + metadata("SimpleITK") + + _HAVE_SIMPLEITK = True +except (ImportError, importlib.metadata.PackageNotFoundError): + pass def snake_to_camel_case(keyword: str): @@ -92,11 +99,13 @@ def move_last_dimension_to_first(arr): def accept_array_like_xarray_torch(image_filter): """Decorator that allows itk.ProcessObject snake_case functions to accept - NumPy array-like, PyTorch Tensor's or xarray DataArray inputs for itk.Image inputs. + NumPy array-like, PyTorch Tensor's, xarray DataArray, or SimpleITK Image + inputs for itk.Image inputs. If a NumPy array-like is passed as an input, output itk.Image's are converted to numpy.ndarray's. If a torch.Tensor is passed as an input, output itk.Image's are converted to torch.Tensors. If a xarray DataArray is passed as an input, output itk.Image's are converted to xarray.DataArray's. + If a SimpleITK Image is passed as an input, output itk.Image's are returned as-is. """ import numpy as np import itk @@ -105,16 +114,23 @@ def accept_array_like_xarray_torch(image_filter): import xarray as xr if _HAVE_TORCH: import torch + if _HAVE_SIMPLEITK: + import SimpleITK as sitk @functools.wraps(image_filter) def image_filter_wrapper(*args, **kwargs): have_array_input = False have_xarray_input = False have_torch_input = False + have_simpleitk_input = False args_list = list(args) for index, arg in enumerate(args): - if _HAVE_XARRAY and isinstance(arg, xr.DataArray): + if _HAVE_SIMPLEITK and isinstance(arg, sitk.Image): + have_simpleitk_input = True + image = itk.image_from_simpleitk(arg) + args_list[index] = image + elif _HAVE_XARRAY and isinstance(arg, xr.DataArray): have_xarray_input = True image = itk.image_from_xarray(arg) args_list[index] = image @@ -135,7 +151,11 @@ def image_filter_wrapper(*args, **kwargs): potential_image_input_kwargs = ("input", "input1", "input2", "input3") for key, value in kwargs.items(): if key.lower() in potential_image_input_kwargs or "image" in key.lower(): - if _HAVE_XARRAY and isinstance(value, xr.DataArray): + if _HAVE_SIMPLEITK and isinstance(value, sitk.Image): + have_simpleitk_input = True + image = itk.image_from_simpleitk(value) + kwargs[key] = image + elif _HAVE_XARRAY and isinstance(value, xr.DataArray): have_xarray_input = True image = itk.image_from_xarray(value) kwargs[key] = image @@ -155,9 +175,16 @@ def image_filter_wrapper(*args, **kwargs): image = itk.image_view_from_array(array) kwargs[key] = image - if have_xarray_input or have_torch_input or have_array_input: - # Convert output itk.Image's to numpy.ndarray's + if ( + have_simpleitk_input + or have_xarray_input + or have_torch_input + or have_array_input + ): + # Convert output itk.Image's based on input type output = image_filter(*tuple(args_list), **kwargs) + if have_simpleitk_input: + return output if isinstance(output, tuple): output_list = list(output) for index, value in enumerate(output_list):