Skip to content

Commit d304506

Browse files
committed
Enhance inverse computation tests to handle infinite values and assert NaN equality
1 parent 0363cdb commit d304506

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

cdl/tests/features/images/operation_unit_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def test_image_inverse() -> None:
231231
with warnings.catch_warnings():
232232
warnings.simplefilter("ignore", category=RuntimeWarning)
233233
exp = np.reciprocal(ima1.data, dtype=float)
234+
exp[np.isinf(exp)] = np.nan
234235
ima2 = cpi.compute_inverse(ima1)
235236
check_array_result("Image inverse", ima2.data, exp)
236237

cdl/tests/features/signals/operation_unit_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
from __future__ import annotations
1515

16+
import warnings
17+
1618
import numpy as np
1719
import pytest
1820

@@ -120,7 +122,11 @@ def test_signal_inverse() -> None:
120122
"""Signal inversion validation test."""
121123
s1 = __create_two_signals()[0]
122124
inv_signal = cps.compute_inverse(s1)
123-
check_array_result("Signal inverse", inv_signal.y, 1.0 / s1.y)
125+
with warnings.catch_warnings():
126+
warnings.simplefilter("ignore", category=RuntimeWarning)
127+
exp = 1.0 / s1.y
128+
exp[np.isinf(exp)] = np.nan
129+
check_array_result("Signal inverse", inv_signal.y, exp)
124130

125131

126132
@pytest.mark.validation

cdl/utils/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def check_array_result(
282282
"""Assert that two arrays are almost equal."""
283283
restxt = f"{title}: {__array_to_str(res)} (expected: {__array_to_str(exp)})"
284284
execenv.print(restxt)
285-
assert np.allclose(res, exp, rtol=rtol, atol=atol), restxt
285+
assert np.allclose(res, exp, rtol=rtol, atol=atol, equal_nan=True), restxt
286286
assert res.dtype == exp.dtype, restxt
287287

288288

0 commit comments

Comments
 (0)