Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
eac7ea7
Implemented rotate3D. Changed rotate to have the correct rotation dir…
Feb 3, 2025
43bf433
[pre-commit] auto fixes from pre-commit hooks
pre-commit-ci[bot] Feb 3, 2025
a6f72d1
Fix more ruff-linter warnings.
Feb 3, 2025
864799c
[pre-commit] auto fixes from pre-commit hooks
pre-commit-ci[bot] Feb 3, 2025
ad01586
Merge branch 'master' into rotate3D
schuenke Feb 5, 2025
b3c2f92
[pre-commit] auto fixes from pre-commit hooks
pre-commit-ci[bot] Feb 5, 2025
5375568
Merge branch 'master' into rotate3D
schuenke Feb 13, 2025
ef9faac
Moved rotate3D.py and test_rotation3D_vs_rotation.py to the correct f…
Feb 13, 2025
0fefe07
[pre-commit] auto fixes from pre-commit hooks
pre-commit-ci[bot] Feb 13, 2025
233c659
Merge branch 'master' into rotate3D
schuenke Feb 13, 2025
1c82973
Merge branch 'master' into rotate3D
schuenke Jun 2, 2025
a6ec920
split tests and add missing abs() in tolerance check
schuenke Jun 2, 2025
ab9c592
[pre-commit] auto fixes from pre-commit hooks
pre-commit-ci[bot] Jun 2, 2025
8ebd657
Merge branch 'imr-framework:master' into rotate3D
mcencini Jun 24, 2025
4be47d2
move Approx to test/conftest.py for reusability across tests. Re-fact…
mcencini Jun 24, 2025
ac096be
fix bug in rotate3D
mcencini Jun 24, 2025
119af14
Remove unwanted print statements in tests
mcencini Jun 24, 2025
0c34875
Remove old comment in test_rotation3D
mcencini Jun 24, 2025
56e3986
Merge branch 'imr-framework:master' into rotate3D
mcencini Jul 31, 2025
c2766f1
Merge branch 'master' into rotate3D
mcencini Jan 14, 2026
1c5b625
Merge branch 'master' into rotate3D
mcencini Jan 16, 2026
41c730a
Merge branch 'master' into rotate3D
mcencini Jan 28, 2026
395e879
Merge branch 'master' into rotate3D
mcencini Jan 29, 2026
a10a070
Merge branch 'master' into rotate3D
mcencini Jan 30, 2026
2eaafb2
Merge branch 'master' into rotate3D
mcencini Jan 30, 2026
fd2827a
Merge branch 'master' into rotate3D
mcencini Feb 25, 2026
99871c3
Merge branch 'master' into rotate3D
mcencini Mar 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/pypulseq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def round_half_up(n, decimals=0):
from pypulseq.opts import Opts
from pypulseq.points_to_waveform import points_to_waveform
from pypulseq.rotate import rotate
from pypulseq.rotate3D import rotate3D
from pypulseq.scale_grad import scale_grad
from pypulseq.split_gradient import split_gradient
from pypulseq.split_gradient_at import split_gradient_at
Expand Down
16 changes: 15 additions & 1 deletion src/pypulseq/rotate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from types import SimpleNamespace
from typing import List, Union
from warnings import warn

import numpy as np

Expand All @@ -20,7 +21,10 @@ def rotate(*args: SimpleNamespace, angle: float, axis: str, system: Union[Opts,
Rotates the corresponding gradient(s) about the given axis by the specified amount. Gradients parallel to the
rotation axis and non-gradient(s) are not affected. Possible rotation axes are 'x', 'y' or 'z'.

See also `pypulseq.Sequence.sequence.add_block()`.
When using rotate() around the y-axis the rotation direction is reversed compared to previous versions to be consistent with rotate3D().
There is no change in behavior of rotate() for rotations around the x- or z-axis.

See also `pypulseq.rotate3D.rotate3D()` and `pypulseq.Sequence.sequence.add_block()`.

Parameters
----------
Expand Down Expand Up @@ -54,6 +58,16 @@ def rotate(*args: SimpleNamespace, angle: float, axis: str, system: Union[Opts,
if len(axes_to_rotate) != 2:
raise ValueError('Incorrect axes specification.')

if axis == 'y':
warning_message = 'When using rotate() around the y-axis the rotation direction is reversed '
warning_message += 'compared to previous versions to be consistent with rotate3D().'
warning_message += 'There is no change in behavior of rotate() for rotations around the x- or z-axis.'
warn(warning_message, stacklevel=2)
axes_to_rotate = [
axes_to_rotate[1],
axes_to_rotate[0],
] # reverse the list to preserve the correct handiness of the rotation matrix

for i in range(len(args)):
event = args[i]

Expand Down
96 changes: 96 additions & 0 deletions src/pypulseq/rotate3D.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from types import SimpleNamespace
from typing import List, Union

import numpy as np

from pypulseq.add_gradients import add_gradients
from pypulseq.opts import Opts
from pypulseq.scale_grad import scale_grad
from pypulseq.utils.tracing import trace, trace_enabled


def __get_grad_abs_mag(grad: SimpleNamespace) -> np.ndarray:
if grad.type == 'trap':
return abs(grad.amplitude)
return np.max(np.abs(grad.waveform))


def rotate3D(
*args: SimpleNamespace, rotation_matrix: np.ndarray[np.float64], system: Union[Opts, None] = None
) -> List[SimpleNamespace]:
"""
Rotates the corresponding gradient(s) by the provided rotation matrix. Non-gradient(s) are not affected.

See also `pypulseq.rotate.rotate()` and `pypulseq.Sequence.sequence.add_block()`.

Parameters
----------
args : SimpleNamespace
Gradient(s).
rotation_matrix : np.ndarray[np.float64]
3x3 rotation matrix by which the gradient(s) are rotated.
system : Opts, default=Opts()
System limits.

Returns
-------
rotated_grads : [SimpleNamespace]
Rotated gradient(s).
"""
if system is None:
system = Opts.default

if rotation_matrix.shape != (3, 3):
raise ValueError('The rotation matrix must have shape (3, 3).')

# First create indexes of the objects to be bypassed or rotated
axes = ['x', 'y', 'z']
events_to_rotate_dict = {}
i_bypass = []

for i in range(len(args)):
event = args[i]
if event.type != 'grad' and event.type != 'trap':
i_bypass.append(i)
else:
if event.channel not in axes:
raise ValueError('Invalid event channel. Expected one of ' + str(axes))
elif event.channel in events_to_rotate_dict:
raise ValueError('More than one gradient for the same channel provided, channel: ' + str(event.channel))
else:
events_to_rotate_dict[event.channel] = event

# Measure of relevant amplitude
max_mag = 0
for axis in axes:
if axis in events_to_rotate_dict:
event = events_to_rotate_dict[axis]
max_mag = max(max_mag, __get_grad_abs_mag(event))
fthresh = 1e-6
thresh = fthresh * max_mag

# Rotate the events (gradients)
rotated_gradients = []
for j in range(3):
grad_out_curr = None
for i in range(3):
if axes[i] not in events_to_rotate_dict or abs(rotation_matrix[j, i]) < fthresh:
continue
scaled_gradient = scale_grad(grad=events_to_rotate_dict[axes[i]], scale=rotation_matrix[j, i])
scaled_gradient.channel = axes[j]
if grad_out_curr is None:
grad_out_curr = scaled_gradient
else:
grad_out_curr = add_gradients((grad_out_curr, scaled_gradient), system=system)
if grad_out_curr is not None and __get_grad_abs_mag(grad_out_curr) >= thresh:
rotated_gradients.append(grad_out_curr)

# Return
bypass = np.take(args, i_bypass)
return_grads = [*bypass, *rotated_gradients]

if trace_enabled():
for grad in return_grads:
grad.trace = trace()

return return_grads
102 changes: 102 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import math
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import numpy.testing as npt
import pytest
from _pytest.python_api import ApproxBase
from scipy.spatial.transform import Rotation as R


# this is currently not used, but might be useful in the future
Expand Down Expand Up @@ -38,3 +43,100 @@ def compare(file1, file2):
assert line1 == line2

return compare


class Approx(ApproxBase):
"""
Fast approximate equality for nested dicts, lists, tuples, SimpleNamespace, and numpy arrays,
derived from pytest's ApproxBase for seamless pytest integration.
"""

def __init__(self, expected, *, rel=1e-6, abs=1e-12, nan_ok=False): # noqa: A002
super().__init__(expected, rel=rel, abs=abs, nan_ok=nan_ok)
self._errors = []

def __eq__(self, actual):
# reset errors
self._errors.clear()
# stack: (path, expected, actual)
stack = [((), self.expected, actual)]
rel_tol, abs_tol, nan_ok, errs = self.rel, self.abs, self.nan_ok, self._errors
isclose = math.isclose

while stack:
path, exp, act = stack.pop()
if isinstance(exp, R):
exp = exp.as_matrix()
if isinstance(act, R):
act = act.as_matrix()

# dict
if isinstance(exp, dict):
if not isinstance(act, dict) or set(exp) != set(act):
errs.append(
f'{".".join(path) or "<root>"}: key-sets differ; expected {set(exp)}, got {set(getattr(act, "keys", lambda: act)())}' # noqa: B023
)
return False
for k in exp:
stack.append((path + (str(k),), exp[k], act[k])) # noqa: RUF005
continue

# list/tuple
if isinstance(exp, (list, tuple)):
if not isinstance(act, type(exp)) or len(exp) != len(act):
errs.append(
f'{".".join(path) or "<root>"}: length/type mismatch; expected {type(exp).__name__}[{len(exp)}], got {type(act).__name__}[{len(act)}]'
)
return False
for idx, (e, a) in enumerate(zip(exp, act)):
stack.append((path + (str(idx),), e, a)) # noqa: RUF005
continue

# SimpleNamespace
if isinstance(exp, SimpleNamespace):
if not isinstance(act, SimpleNamespace):
errs.append(f'{".".join(path)}: expected SimpleNamespace, got {type(act).__name__}')
return False
stack.append((path, exp.__dict__, act.__dict__))
continue

# numpy arrays
if isinstance(exp, np.ndarray) or isinstance(act, np.ndarray):
try:
npt.assert_allclose(act, exp, rtol=rel_tol, atol=abs_tol, equal_nan=nan_ok)
except AssertionError as e:
errs.append(f'{".".join(path) or "<array>"}: {e}')
return False
continue

# scalar or fallback
try:
if not (
isclose(act, exp, rel_tol=rel_tol, abs_tol=abs_tol)
or (nan_ok and math.isnan(act) and math.isnan(exp))
):
errs.append(
f'{".".join(path) or "<value>"}: {act!r} != {exp!r} within (rel={rel_tol}, abs={abs_tol})'
)
return False
except TypeError:
approx = pytest.approx(exp, rel=rel_tol, abs=abs_tol, nan_ok=nan_ok)
if act != approx:
msgs = approx._repr_compare(act)
errs.extend(msgs)
return False

return True

def __repr__(self):
return str(self.expected)

def _repr_compare(self, actual):
# populate errors
_ = actual == self
return self._errors


# Rotation Matrix creation routine
def get_rotation_matrix(channel, angle):
return R.from_euler(channel, angle, degrees=False).as_matrix()
98 changes: 98 additions & 0 deletions tests/test_rotation3D_vs_rotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import numpy as np
import pypulseq
import pytest
from pypulseq import rotate, rotate3D

from conftest import Approx, get_rotation_matrix

channel_list = ['x', 'y', 'z']
angle_deg_list = [0.0, 0.1, 1.0, 60.0, 90.0, 180.0, 360.0, 400.1, -0.1, -1.0, -90.0, -180.0, -360.0]

grad_list = [
pypulseq.make_trapezoid(channel='x', amplitude=1, duration=13),
pypulseq.make_trapezoid(channel='y', amplitude=1, duration=13),
pypulseq.make_trapezoid(channel='z', amplitude=1, duration=13),
pypulseq.make_trapezoid(channel='x', amplitude=2, duration=5),
pypulseq.make_trapezoid(channel='y', amplitude=2, duration=5),
pypulseq.make_trapezoid(channel='z', amplitude=2, duration=5),
pypulseq.make_extended_trapezoid('x', [0, 5, 1, 3], convert_to_arbitrary=True, times=[1, 3, 4, 7]),
pypulseq.make_extended_trapezoid('y', [0, 5, 1, 3], convert_to_arbitrary=True, times=[1, 3, 4, 7]),
pypulseq.make_extended_trapezoid('z', [0, 5, 1, 3], convert_to_arbitrary=True, times=[1, 3, 4, 7]),
pypulseq.make_extended_trapezoid('x', [0, 5, 1, 3], convert_to_arbitrary=False, times=[1, 3, 4, 7]),
pypulseq.make_extended_trapezoid('y', [0, 5, 1, 3], convert_to_arbitrary=False, times=[1, 3, 4, 7]),
pypulseq.make_extended_trapezoid('z', [0, 5, 1, 3], convert_to_arbitrary=False, times=[1, 3, 4, 7]),
pypulseq.make_extended_trapezoid('x', [0, 3, 2, 3], convert_to_arbitrary=False, times=[1, 2, 3, 4]),
pypulseq.make_extended_trapezoid('y', [0, 3, 2, 3], convert_to_arbitrary=False, times=[1, 2, 3, 4]),
pypulseq.make_extended_trapezoid('z', [0, 3, 2, 3], convert_to_arbitrary=False, times=[1, 2, 3, 4]),
]


def __list_to_dict(gradient_set):
channel_grad_dict = {}
assert len(gradient_set) <= 3, 'Each gradient set must not have more than three gradients.'
for grad in gradient_set:
assert grad.channel in channel_list, 'Gradients must have channel "x", "y" or "z".'
assert grad.channel not in channel_grad_dict, (
'There must not be two gradients with the same channel in each set.'
)
channel_grad_dict[grad.channel] = grad
return channel_grad_dict


@pytest.mark.filterwarnings('ignore:When using rotate():UserWarning')
@pytest.mark.parametrize('angle_deg', angle_deg_list)
def test_rotation3D_vs_rotation(angle_deg):
"""Compare results of rotate and rotate3D."""
angle_rad = np.deg2rad(angle_deg)

for rotation_axis in channel_list:
rotation_matrix = get_rotation_matrix(rotation_axis, angle_rad)

for grad in grad_list:
grads_rotated = __list_to_dict(rotate(grad, angle=angle_rad, axis=rotation_axis))
grads_rotated3D = __list_to_dict(rotate3D(grad, rotation_matrix=rotation_matrix))

assert grads_rotated3D == Approx(grads_rotated, abs=1e-4, rel=1e-4), (
f'Result of rotate and rotate3D should be the same! Angle: {angle_deg}, Axis: {rotation_axis}, Grad: {grad}'
)


@pytest.mark.filterwarnings('ignore:When using rotate():UserWarning')
@pytest.mark.parametrize('angle_deg', angle_deg_list)
def test_rotation3D_vs_rotation_double(angle_deg):
"""Compare results of rotate and rotate3D."""
angle_rad = np.deg2rad(angle_deg)

for rotation_axis in channel_list:
rotation_matrix = get_rotation_matrix(rotation_axis, angle_rad)

for grad in grad_list:
grads_rotated = rotate(grad, angle=angle_rad, axis=rotation_axis)
grads_rotated_double = __list_to_dict(rotate(*grads_rotated, angle=angle_rad, axis=rotation_axis))

grads_rotated3D = rotate3D(grad, rotation_matrix=rotation_matrix)
grads_rotated3D_double = __list_to_dict(rotate3D(*grads_rotated3D, rotation_matrix=rotation_matrix))

assert grads_rotated3D_double == Approx(grads_rotated_double, abs=1e-4, rel=1e-4), (
f'Result of double rotate and rotate3D should be the same! Angle: {angle_deg}, Axis: {rotation_axis}, Grad: {grad}'
)


@pytest.mark.filterwarnings('ignore:When using rotate():UserWarning')
@pytest.mark.parametrize('angle_deg', angle_deg_list)
def test_rotation3D_vs_rotation_double_2(angle_deg):
"""Compare results of rotate and rotate3D."""
angle_rad = np.deg2rad(angle_deg)

for rotation_axis in channel_list:
rotation_matrix = get_rotation_matrix(rotation_axis, angle_rad)

for grad in grad_list:
grads_rotated = rotate(grad, angle=angle_rad, axis=rotation_axis)
grads_rotated_double = __list_to_dict(rotate(*grads_rotated, angle=angle_rad, axis=rotation_axis))

grads_rotated3D_double = __list_to_dict(rotate3D(grad, rotation_matrix=rotation_matrix @ rotation_matrix))

assert grads_rotated3D_double == Approx(grads_rotated_double, abs=1e-4, rel=1e-4), (
f'Result of second double rotate and rotate3D should be the same! Angle: {angle_deg}, Axis: {rotation_axis}, Grad: {grad}'
)
Loading
Loading