diff --git a/news/array_index.rst b/news/array_index.rst new file mode 100644 index 00000000..6f687373 --- /dev/null +++ b/news/array_index.rst @@ -0,0 +1,23 @@ +**Added:** + +* function to return the index of the closest value to the specified value in an array. + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/src/diffpy/utils/diffraction_objects.py b/src/diffpy/utils/diffraction_objects.py index 54374fa7..48fffa7f 100644 --- a/src/diffpy/utils/diffraction_objects.py +++ b/src/diffpy/utils/diffraction_objects.py @@ -248,15 +248,29 @@ def _set_array_from_range(self, begin, end, step_size=None, n_steps=None): array = np.linspace(begin, end, n_steps) return array - def get_angle_index(self, angle): - count = 0 - for i, target in enumerate(self.angles): - if angle == target: - return i - else: - count += 1 - if count >= len(self.angles): - raise IndexError(f"WARNING: no angle {angle} found in angles list") + def get_array_index(self, value, xtype=None): + """ + returns the index of the closest value in the array associated with the specified xtype + + Parameters + ---------- + xtype str + the xtype used to access the array + value float + the target value to search for + + Returns + ------- + the index of the value in the array + """ + + if xtype is None: + xtype = self.input_xtype + array = self.on_xtype(xtype)[0] + if len(array) == 0: + raise ValueError(f"The '{xtype}' array is empty. Please ensure it is initialized.") + i = (np.abs(array - value)).argmin() + return i def _set_xarrays(self, xarray, xtype): self.all_arrays = np.empty(shape=(len(xarray), 4)) diff --git a/tests/test_diffraction_objects.py b/tests/test_diffraction_objects.py index 9db6932e..6fa81d3e 100644 --- a/tests/test_diffraction_objects.py +++ b/tests/test_diffraction_objects.py @@ -1,3 +1,4 @@ +import re from pathlib import Path import numpy as np @@ -211,6 +212,31 @@ def _test_valid_diffraction_objects(actual_diffraction_object, function, expecte return np.allclose(actual_array, expected_array) +params_index = [ + # UC1: exact match + ([4 * np.pi, np.array([30.005, 60]), np.array([1, 2]), "tth", "tth", 30.005], [0]), + # UC2: target value lies in the array, returns the (first) closest index + ([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "tth", 45], [0]), + ([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "q", 0.25], [0]), + # UC3: target value out of the range, returns the closest index + ([4 * np.pi, np.array([0.25, 0.5, 0.71]), np.array([1, 2, 3]), "q", "q", 0.1], [0]), + ([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "tth", 63], [1]), +] + + +@pytest.mark.parametrize("inputs, expected", params_index) +def test_get_array_index(inputs, expected): + test = DiffractionObject(wavelength=inputs[0], xarray=inputs[1], yarray=inputs[2], xtype=inputs[3]) + actual = test.get_array_index(value=inputs[5], xtype=inputs[4]) + assert actual == expected[0] + + +def test_get_array_index_bad(): + test = DiffractionObject(wavelength=2 * np.pi, xarray=np.array([]), yarray=np.array([]), xtype="tth") + with pytest.raises(ValueError, match=re.escape("The 'tth' array is empty. Please ensure it is initialized.")): + test.get_array_index(value=30) + + def test_dump(tmp_path, mocker): x, y = np.linspace(0, 5, 6), np.linspace(0, 5, 6) directory = Path(tmp_path)