Skip to content

Commit 70e1978

Browse files
Improve Affine transform documentation and add compute_w_affine tests
- Add Note section documenting center-origin coordinate system assumption - Clarify normalized parameter documentation with user-friendly explanation - Add comprehensive docstring to compute_w_affine method - Add focused unit tests for compute_w_affine (2D/3D identity, different sizes, output shape, torch input compatibility) Fixes #7092 Signed-off-by: Mohamed Salah <eng.mohamed.tawab@gmail.com>
1 parent 2147c11 commit 70e1978

File tree

2 files changed

+73
-4
lines changed

2 files changed

+73
-4
lines changed

monai/transforms/spatial/array.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2166,6 +2166,13 @@ class Affine(InvertibleTransform, LazyTransform):
21662166
21672167
This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
21682168
for more information.
2169+
2170+
Note:
2171+
This transform assumes that the origin of the coordinate system is at the spatial center
2172+
of the image. When applying transformations (rotation, scaling, etc.), they are performed
2173+
relative to this center point. If you need transformations around a different origin,
2174+
you may need to compose this transform with translation operations or adjust your affine
2175+
matrix accordingly.
21692176
"""
21702177

21712178
backend = list(set(AffineGrid.backend) & set(Resample.backend))
@@ -2228,10 +2235,12 @@ def __init__(
22282235
When `mode` is an integer, using numpy/cupy backends, this argument accepts
22292236
{'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
22302237
See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
2231-
normalized: indicating whether the provided `affine` is defined to include a normalization
2232-
transform converting the coordinates from `[-(size-1)/2, (size-1)/2]` (defined in ``create_grid``) to
2233-
`[0, size - 1]` or `[-1, 1]` in order to be compatible with the underlying resampling API.
2234-
If `normalized=False`, additional coordinate normalization will be applied before resampling.
2238+
normalized: indicates whether the provided `affine` matrix already includes coordinate
2239+
normalization. Set to ``True`` if your affine matrix is designed to work with normalized
2240+
coordinates (e.g., from image processing libraries that use normalized coordinate systems).
2241+
Set to ``False`` (default) if your affine matrix works with pixel/voxel coordinates centered
2242+
at the image center. When ``False``, MONAI will automatically apply the necessary coordinate
2243+
transformations. Most users should use the default ``False``.
22352244
See also: :py:func:`monai.networks.utils.normalize_transform`.
22362245
device: device on which the tensor will be allocated.
22372246
dtype: data type for resampling computation. Defaults to ``float32``.
@@ -2323,6 +2332,24 @@ def __call__(
23232332

23242333
@classmethod
23252334
def compute_w_affine(cls, spatial_rank, mat, img_size, sp_size):
2335+
"""
2336+
Compute the affine matrix for transforming image coordinates, accounting for
2337+
center-based coordinate system.
2338+
2339+
This function adjusts the provided affine transformation matrix to work with images
2340+
where transformations are applied relative to the image center rather than the origin.
2341+
It composes the input matrix with translation operations that shift between
2342+
corner-based and center-based coordinate systems.
2343+
2344+
Args:
2345+
spatial_rank: number of spatial dimensions (e.g., 2 for 2D, 3 for 3D).
2346+
mat: the base affine transformation matrix to be adjusted.
2347+
img_size: spatial dimensions of the input image.
2348+
sp_size: spatial dimensions of the output (transformed) image.
2349+
2350+
Returns:
2351+
The adjusted affine matrix that can be applied to image coordinates.
2352+
"""
23262353
r = int(spatial_rank)
23272354
mat = to_affine_nd(r, mat)
23282355
shift_1 = create_translate(r, [float(d - 1) / 2 for d in img_size[:r]])

tests/transforms/test_affine.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,48 @@ def test_affine(self, input_param, input_data, expected_val):
199199
)
200200

201201

202+
class TestComputeWAffine(unittest.TestCase):
203+
def test_identity_2d(self):
204+
"""Identity matrix with same input/output size should produce pure translation to/from center."""
205+
mat = np.eye(3)
206+
img_size = (4, 4)
207+
sp_size = (4, 4)
208+
result = Affine.compute_w_affine(2, mat, img_size, sp_size)
209+
# For identity transform with same sizes, result should be identity
210+
assert_allclose(result, np.eye(3), atol=1e-6)
211+
212+
def test_identity_3d(self):
213+
"""Identity matrix in 3D with same input/output size."""
214+
mat = np.eye(4)
215+
img_size = (6, 6, 6)
216+
sp_size = (6, 6, 6)
217+
result = Affine.compute_w_affine(3, mat, img_size, sp_size)
218+
assert_allclose(result, np.eye(4), atol=1e-6)
219+
220+
def test_different_sizes(self):
221+
"""When img_size != sp_size, result should include net translation."""
222+
mat = np.eye(3)
223+
img_size = (4, 4)
224+
sp_size = (8, 8)
225+
result = Affine.compute_w_affine(2, mat, img_size, sp_size)
226+
# Translation should account for the shift: (4-1)/2 - (8-1)/2 = 1.5 - 3.5 = -2.0
227+
expected_translation = np.array([(d1 - 1) / 2 - (d2 - 1) / 2 for d1, d2 in zip(img_size, sp_size)])
228+
assert_allclose(result[:2, 2], expected_translation, atol=1e-6)
229+
230+
def test_output_shape(self):
231+
"""Output should be (r+1) x (r+1) matrix."""
232+
for r in [2, 3]:
233+
mat = np.eye(r + 1)
234+
result = Affine.compute_w_affine(r, mat, (4,) * r, (4,) * r)
235+
self.assertEqual(result.shape, (r + 1, r + 1))
236+
237+
def test_torch_input(self):
238+
"""Method should accept torch tensor input."""
239+
mat = torch.eye(3)
240+
result = Affine.compute_w_affine(2, mat, (4, 4), (4, 4))
241+
assert_allclose(result, np.eye(3), atol=1e-6)
242+
243+
202244
@unittest.skipUnless(optional_import("scipy")[1], "Requires scipy library.")
203245
class TestAffineConsistency(unittest.TestCase):
204246
@parameterized.expand([[7], [8], [9]])

0 commit comments

Comments
 (0)