Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions nion/data/Core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,30 +1801,48 @@ def calculate_data() -> _ImageDataType:


def function_warp(data_and_metadata_in: _DataAndMetadataLike, coordinates_in: typing.Sequence[_DataAndMetadataLike], order: int = 1) -> DataAndMetadata.DataAndMetadata:
"""Warp or unwarp input data using an N-dimensional warp map.

The warp map is applied along N axes and broadcast over any additional
dimensions in the input, allowing a single warp map to be used for
higher-dimensional data (e.g., image sequences). For multichannel data
such as RGB/RGBA, the warp is applied uniformly to all channels.
"""
Comment thread
Tiomat85 marked this conversation as resolved.
data_and_metadata = DataAndMetadata.promote_ndarray(data_and_metadata_in)
coordinates = [DataAndMetadata.promote_ndarray(c) for c in coordinates_in]
coords = numpy.moveaxis(numpy.dstack([coordinate.data for coordinate in coordinates]), -1, 0)
coords = numpy.stack([c.data.astype(float) for c in coordinates], axis=0)
data = data_and_metadata._data_ex
if data_and_metadata.is_data_rgb:
rgb: numpy.typing.NDArray[numpy.uint8] = numpy.zeros(tuple(data_and_metadata.dimensional_shape) + (3,), numpy.uint8)
rgb[..., 0] = scipy.ndimage.map_coordinates(data[..., 0], coords, order=order)
rgb[..., 1] = scipy.ndimage.map_coordinates(data[..., 1], coords, order=order)
rgb[..., 2] = scipy.ndimage.map_coordinates(data[..., 2], coords, order=order)
return DataAndMetadata.new_data_and_metadata(data=rgb,
dimensional_calibrations=data_and_metadata.dimensional_calibrations,
intensity_calibration=data_and_metadata.intensity_calibration)
elif data_and_metadata.is_data_rgba:
rgba: numpy.typing.NDArray[numpy.uint8] = numpy.zeros(tuple(data_and_metadata.dimensional_shape) + (4,), numpy.uint8)
rgba[..., 0] = scipy.ndimage.map_coordinates(data[..., 0], coords, order=order)
rgba[..., 1] = scipy.ndimage.map_coordinates(data[..., 1], coords, order=order)
rgba[..., 2] = scipy.ndimage.map_coordinates(data[..., 2], coords, order=order)
rgba[..., 3] = scipy.ndimage.map_coordinates(data[..., 3], coords, order=order)
return DataAndMetadata.new_data_and_metadata(data=rgba,
num_frame_dims = coords.shape[0]

if data_and_metadata.is_data_rgb_type:
# Last dimension is channels
leading_shape = data.shape[:-num_frame_dims - 1]
output_shape = leading_shape + coords.shape[1:]
channels = 3 if data_and_metadata.is_data_rgb else 4
output = numpy.zeros(tuple(output_shape) + (channels,), numpy.uint8)

# scipy map_coordinates does not broadcast by default, so need to loop
for index in numpy.ndindex(leading_shape):
for chan in range(channels):
output[index + (..., chan)] = scipy.ndimage.map_coordinates(
data[index + (..., chan)],
coords,
order=order)

return DataAndMetadata.new_data_and_metadata(data=output,
dimensional_calibrations=data_and_metadata.dimensional_calibrations,
intensity_calibration=data_and_metadata.intensity_calibration)
else:
leading_shape = data.shape[:-num_frame_dims]
output_shape = leading_shape + coords.shape[1:]
output = numpy.zeros(output_shape, dtype=data.dtype)

# scipy map_coordinates does not broadcast by default, so need to loop
for index in numpy.ndindex(leading_shape):
output[index] = scipy.ndimage.map_coordinates(data[index], coords, order=order)

return DataAndMetadata.new_data_and_metadata(
data=scipy.ndimage.map_coordinates(data, coords, order=order),
data=output,
dimensional_calibrations=data_and_metadata.dimensional_calibrations,
intensity_calibration=data_and_metadata.intensity_calibration)

Expand Down
95 changes: 95 additions & 0 deletions nion/data/test/Core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,6 +1336,101 @@ def test_fft_zero_component_calibration(self) -> None:
result4 = Core.function_fft(xdata4)
self.assertAlmostEqual(0.0, result4.dimensional_calibrations[0].convert_to_calibrated_value(7.5))

## WARP TESTS
# Helper func
def _create_warp_test_data(self,
input_shape: tuple[int,...],
output_shape: tuple[int, ...] | None = None,
identity: bool = False,
mode: str = "greyscale") -> tuple[DataAndMetadata.DataAndMetadata, list[numpy.ndarray]]:
# Determine data type and channels based on mode
dtype: numpy.typing.DTypeLike
if mode == "greyscale":
dtype = float
channels = None
elif mode == "rgb":
dtype = numpy.uint8
channels = 3
elif mode == "rgba":
dtype = numpy.uint8
channels = 4
else:
raise ValueError(f"Invalid mode: {mode}. Choose 'greyscale', 'rgb', or 'rgba'.")

# Prepare input shape for data array
if channels is None:
full_shape = input_shape
else:
full_shape = input_shape + (channels,)

# Input data: sequential numbers for easy validation
data = numpy.arange(numpy.prod(full_shape), dtype=dtype).reshape(full_shape)
src = DataAndMetadata.new_data_and_metadata(data=data)

# Determine output grid shape
if output_shape is None:
height, width = input_shape[-2:]
else:
height, width = output_shape[-2:]

# Create warp coordinates
if identity:
# Identity warp: map output coordinates to same as input indices
warp_y, warp_x = numpy.meshgrid(
numpy.arange(input_shape[-2]),
numpy.arange(input_shape[-1]),
indexing="ij"
)
else:
# Resampling / scaling: map output grid into input index space
input_height, input_width = input_shape[-2:]
y = numpy.arange(0, input_height, input_height / height)
x = numpy.arange(0, input_width, input_width / width)
warp_y, warp_x = numpy.meshgrid(y, x, indexing="ij")

return src, [warp_y, warp_x]

def _validate_warp_shape(self, src: DataAndMetadata.DataAndMetadata, dst: DataAndMetadata.DataAndMetadata, coords: list[numpy.ndarray], is_channel_data: bool = False) -> None:
n_dims = len(coords) # number of warped dimensions
output_shape = coords[0].shape # shape of warp grid
expected_shape = src.data_shape[:-n_dims] + output_shape

if is_channel_data:
expected_shape = src.data_shape[:-n_dims-1] + output_shape + (src.data_shape[-1],)

assert dst.data_shape == expected_shape, f"Output shape mismatch: {dst.data_shape} != {expected_shape}"

def test_warp_identity(self) -> None:
src, coords = self._create_warp_test_data(input_shape=(4, 4), identity=True)
dst = Core.function_warp(src, coords)
self._validate_warp_shape(src, dst, coords)

def test_warp_sequence(self) -> None:
src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(4, 4))
dst = Core.function_warp(src, coords)
self._validate_warp_shape(src, dst, coords)

def test_warp_upscale(self) -> None:
# Input 4x4, warp to 8x8
src, coords = self._create_warp_test_data(input_shape=(4, 4), output_shape=(8, 8))
dst = Core.function_warp(src, coords)
self._validate_warp_shape(src, dst, coords)

def test_warp_sequence_upscale(self) -> None:
src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(6, 8, 8))
dst = Core.function_warp(src, coords)
self._validate_warp_shape(src, dst, coords)

def test_warp_rgb(self) -> None:
src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(4, 4), mode="rgb")
dst = Core.function_warp(src, coords)
self._validate_warp_shape(src, dst, coords, is_channel_data=True)

def test_warp_rgba(self) -> None:
src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(4, 4), mode="rgba")
dst = Core.function_warp(src, coords)
self._validate_warp_shape(src, dst, coords, is_channel_data=True)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.DEBUG)
Expand Down
Loading