Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions Wrapping/Generators/Python/Tests/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,3 +655,76 @@ def assert_images_equal(base, test):
except ImportError:
print("vtk not imported. Skipping vtk conversion tests")
pass

# Test SimpleITK to ITK conversion
try:
import SimpleITK as sitk

print("Testing SimpleITK conversion")

# Test 2D scalar image
sitk_image = sitk.Image([12, 8], sitk.sitkFloat32)
sitk_image.SetSpacing([0.5, 1.5])
sitk_image.SetOrigin([10.0, 20.0])
theta = np.radians(30)
cosine = np.cos(theta)
sine = np.sin(theta)
sitk_image.SetDirection([cosine, -sine, sine, cosine])
# Fill with known data
arr_2d = np.random.rand(8, 12).astype(np.float32)
for y in range(8):
for x in range(12):
sitk_image[x, y] = float(arr_2d[y, x])

itk_image = itk.image_from_simpleitk(sitk_image)
assert np.array_equal(
np.array(sitk_image.GetSpacing()), np.array(itk.spacing(itk_image))
)
assert np.array_equal(
np.array(sitk_image.GetOrigin()), np.array(itk.origin(itk_image))
)
sitk_dir = np.array(sitk_image.GetDirection()).reshape(2, 2)
itk_dir = itk.array_from_matrix(itk_image.GetDirection())
assert np.allclose(sitk_dir, itk_dir)
assert np.array_equal(
sitk.GetArrayFromImage(sitk_image), itk.array_from_image(itk_image)
)

# Test 3D scalar image
sitk_3d = sitk.Image([4, 6, 8], sitk.sitkFloat32)
sitk_3d.SetSpacing([1.0, 2.0, 3.0])
sitk_3d.SetOrigin([5.0, 10.0, 15.0])
itk_3d = itk.image_from_simpleitk(sitk_3d)
assert np.array_equal(np.array(sitk_3d.GetSpacing()), np.array(itk.spacing(itk_3d)))
assert np.array_equal(np.array(sitk_3d.GetOrigin()), np.array(itk.origin(itk_3d)))
assert itk_3d.GetImageDimension() == 3

# Test vector image (multi-component -> VectorImage)
sitk_vector = sitk.Image([10, 10], sitk.sitkVectorFloat32, 3)
sitk_vector.SetSpacing([0.8, 1.2])
itk_vector = itk.image_from_simpleitk(sitk_vector)
assert itk_vector.GetNumberOfComponentsPerPixel() == 3
assert isinstance(itk_vector, itk.VectorImage[itk.F, 2])
assert np.array_equal(
np.array(sitk_vector.GetSpacing()), np.array(itk.spacing(itk_vector))
)

# Test MetaDataDictionary preservation
sitk_meta = sitk.Image([4, 4], sitk.sitkFloat32)
sitk_meta.SetMetaData("test_key", "test_value")
sitk_meta.SetMetaData("0010|0010", "patient_name")
itk_meta = itk.image_from_simpleitk(sitk_meta)
meta_dict = itk_meta.GetMetaDataDictionary()
assert meta_dict["test_key"] == "test_value"
assert meta_dict["0010|0010"] == "patient_name"

# Test auto-detection in filter function
sitk_input = sitk.Image([32, 32], sitk.sitkFloat32)
result = itk.median_image_filter(sitk_input, radius=1)
assert isinstance(result, itk.Image[itk.F, 2])

print("SimpleITK conversion tests passed")

except ImportError:
print("SimpleITK not imported. Skipping SimpleITK conversion tests")
pass
83 changes: 83 additions & 0 deletions Wrapping/Generators/Python/itk/support/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
"image_from_xarray",
"vtk_image_from_image",
"image_from_vtk_image",
"image_from_simpleitk",
"dict_from_image",
"image_from_dict",
"image_intensity_min_max",
Expand Down Expand Up @@ -790,6 +791,88 @@ def image_from_vtk_image(vtk_image: vtk.vtkImageData) -> itkt.ImageBase:
return l_image


def image_from_simpleitk(sitk_image) -> "itkt.ImageBase":
"""Convert a SimpleITK Image to an itk.Image.

The source image is accessed through generic interfaces: the NumPy
``__array__`` protocol for pixel data, ``keys()`` for available
metadata keys, and dictionary-style ``[]`` access for metadata
values. The recognised spatial keys are ``'spacing'``,
``'origin'``, and ``'direction'``; all other keys returned by
``keys()`` are copied into the ITK MetaDataDictionary.

This makes the function forward-compatible with other image
libraries that expose the same conventions (e.g. SimpleITK).

Pixel data is copied once (SimpleITK buffers are read-only).
Multi-component images are converted to itk.VectorImage.

Parameters
----------
sitk_image :
A SimpleITK Image object (or any object that supports
``__array__``, ``keys()``, and dictionary ``[]`` access).

Returns
-------
image :
The resulting itk.Image (or itk.VectorImage for multi-component pixels).
"""
import itk

array = np.array(sitk_image)
dim = array.ndim

spatial_keys = {"spacing", "origin", "direction"}

if hasattr(sitk_image, "keys"):
keys = list(sitk_image.keys())
else:
# Probe spatial keys directly (e.g. SimpleITK 3.x supports [] but not keys())
keys = []
for k in spatial_keys:
try:
sitk_image[k]
keys.append(k)
except (KeyError, TypeError, AttributeError):
pass
# Collect MetaData keys via dedicated accessor if available
if hasattr(sitk_image, "GetMetaDataKeys"):
keys.extend(sitk_image.GetMetaDataKeys())

if "spacing" in keys:
spacing = sitk_image["spacing"]
dim = len(spacing)

is_vector = array.ndim > dim

if is_vector:
PixelType = _get_itk_pixelid(array)
ImageType = itk.VectorImage[PixelType, dim]
l_image = itk.image_view_from_array(array, ttype=ImageType)
else:
l_image = itk.image_view_from_array(array)

if "spacing" in keys:
l_image.SetSpacing(spacing)

if "origin" in keys:
l_image.SetOrigin(sitk_image["origin"])

if "direction" in keys:
direction = np.array(sitk_image["direction"]).reshape(dim, dim)
l_image.SetDirection(direction)

Comment thread
hjmjohnson marked this conversation as resolved.
for key in keys:
if key not in spatial_keys:
l_image.GetMetaDataDictionary()[key] = sitk_image[key]

# Keep a reference to the numpy array to prevent garbage collection
l_image._SetBase(array)

return l_image


def dict_from_image(image: itkt.Image) -> dict:
"""Serialize a Python itk.Image object to a pickable Python dictionary."""
import itk
Expand Down
37 changes: 32 additions & 5 deletions Wrapping/Generators/Python/itk/support/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@
_HAVE_TORCH = True
except importlib.metadata.PackageNotFoundError:
pass
_HAVE_SIMPLEITK = False
try:
metadata("SimpleITK")

_HAVE_SIMPLEITK = True
except (ImportError, importlib.metadata.PackageNotFoundError):
pass


def snake_to_camel_case(keyword: str):
Expand Down Expand Up @@ -92,11 +99,13 @@ def move_last_dimension_to_first(arr):

def accept_array_like_xarray_torch(image_filter):
"""Decorator that allows itk.ProcessObject snake_case functions to accept
NumPy array-like, PyTorch Tensor's or xarray DataArray inputs for itk.Image inputs.
NumPy array-like, PyTorch Tensor's, xarray DataArray, or SimpleITK Image
inputs for itk.Image inputs.

If a NumPy array-like is passed as an input, output itk.Image's are converted to numpy.ndarray's.
If a torch.Tensor is passed as an input, output itk.Image's are converted to torch.Tensors.
If a xarray DataArray is passed as an input, output itk.Image's are converted to xarray.DataArray's.
If a SimpleITK Image is passed as an input, output itk.Image's are returned as-is.
"""
import numpy as np
import itk
Expand All @@ -105,16 +114,23 @@ def accept_array_like_xarray_torch(image_filter):
import xarray as xr
if _HAVE_TORCH:
import torch
if _HAVE_SIMPLEITK:
import SimpleITK as sitk

@functools.wraps(image_filter)
def image_filter_wrapper(*args, **kwargs):
have_array_input = False
have_xarray_input = False
have_torch_input = False
have_simpleitk_input = False

args_list = list(args)
for index, arg in enumerate(args):
if _HAVE_XARRAY and isinstance(arg, xr.DataArray):
if _HAVE_SIMPLEITK and isinstance(arg, sitk.Image):
have_simpleitk_input = True
image = itk.image_from_simpleitk(arg)
args_list[index] = image
elif _HAVE_XARRAY and isinstance(arg, xr.DataArray):
have_xarray_input = True
image = itk.image_from_xarray(arg)
args_list[index] = image
Expand All @@ -135,7 +151,11 @@ def image_filter_wrapper(*args, **kwargs):
potential_image_input_kwargs = ("input", "input1", "input2", "input3")
for key, value in kwargs.items():
if key.lower() in potential_image_input_kwargs or "image" in key.lower():
if _HAVE_XARRAY and isinstance(value, xr.DataArray):
if _HAVE_SIMPLEITK and isinstance(value, sitk.Image):
have_simpleitk_input = True
image = itk.image_from_simpleitk(value)
kwargs[key] = image
elif _HAVE_XARRAY and isinstance(value, xr.DataArray):
have_xarray_input = True
image = itk.image_from_xarray(value)
kwargs[key] = image
Expand All @@ -155,9 +175,16 @@ def image_filter_wrapper(*args, **kwargs):
image = itk.image_view_from_array(array)
kwargs[key] = image

if have_xarray_input or have_torch_input or have_array_input:
# Convert output itk.Image's to numpy.ndarray's
if (
have_simpleitk_input
or have_xarray_input
or have_torch_input
or have_array_input
):
# Convert output itk.Image's based on input type
output = image_filter(*tuple(args_list), **kwargs)
if have_simpleitk_input:
return output
if isinstance(output, tuple):
output_list = list(output)
for index, value in enumerate(output_list):
Expand Down
Loading