diff --git a/packages/data/src/pyearthtools/data/patterns/__init__.py b/packages/data/src/pyearthtools/data/patterns/__init__.py index 094695bd..35325321 100644 --- a/packages/data/src/pyearthtools/data/patterns/__init__.py +++ b/packages/data/src/pyearthtools/data/patterns/__init__.py @@ -84,4 +84,3 @@ ) from pyearthtools.data.patterns.parser import ParsingPattern from pyearthtools.data.patterns.static import Static - diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py index 5a5eda05..3226ea4f 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py @@ -238,7 +238,8 @@ def _unflatten(data, shape): if self.shape_attempt: shape_attempt = self._configure_shape_attempt() - if shape_attempt: + # if self.shape_attempt is truthy then shape_attempt is always truthy. + if shape_attempt: # pragma: no cover attempts.append((*parsed_shape, *shape_attempt[-1 * self.flatten_dims :])) # type: ignore for attemp in attempts: @@ -315,7 +316,7 @@ def __init__( """ super().__init__( split_tuples=False, - recognised_types=(np.ndarray), + recognised_types=(np.ndarray, tuple), ) self.record_initialisation() diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/reshape.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/reshape.py index bcab81a8..d748672a 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/reshape.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/reshape.py @@ -169,7 +169,7 @@ def apply_func(self, dataset: xr.Dataset) -> xr.Dataset: coord_size = dataset[var][discovered_coord].values coord_size = coord_size if isinstance(coord_size, np.ndarray) else np.array(coord_size) - if coord_size.size == 1 and False: + if coord_size.size == 1 and False: # pragma: nocover # TODO: review why this if stmt was put here. coord_val = weak_cast_to_int(dataset[var][discovered_coord].values) new_ds[f"{var}{coord_val}"] = Drop(discovered_coord, ignore_missing=True)(dataset[var]) @@ -234,7 +234,8 @@ def apply_func(self, dataset: xr.Dataset) -> xr.Dataset | xr.DataArray: dataset = SetType(**{str(coord): dtype})(dataset) ## Add stored encoding if there - if f"{coord}-dtype" in dataset.attrs: + # this is always False since attributes always get overwritten. + if f"{coord}-dtype" in dataset.attrs: # pragma: no cover dtype = dataset.attrs.pop(f"{coord}-dtype") dataset[coord].encoding.update(dtype=dtype) diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py index d788ec67..ad698caf 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py @@ -13,6 +13,7 @@ # limitations under the License. from pyearthtools.pipeline.operations.numpy import reshape +from unittest.mock import MagicMock import numpy as np import pytest @@ -115,6 +116,53 @@ def test_Flattener_1_dim(): assert np.all(undo_output == random_array), "Undo Flatten 1 dimension." +def success_then_fail(self, *args, **kwargs): + yield self + raise ValueError() + + +def test_Flattener_exceptions(): + """Tests all the exceptions that can be raised in the Flattener class.""" + # try instantiating flattener with invalid dim + with pytest.raises(ValueError): + reshape.Flattener(flatten_dims=0) + + # test undo without apply + f = reshape.Flattener(shape_attempt=(2, 1, 1)) + random_array = np.random.randn(4, 3, 5) + with pytest.raises(RuntimeError): + f.undo(random_array) + + # _configure_shape_attempt error when apply not run + with pytest.raises(RuntimeError): + f._configure_shape_attempt() + + # test undo when flatten_dims unset + output = f.apply(random_array) + f.flatten_dims = None # "accidentally" overwrite the dims + with pytest.raises(RuntimeError): + f.undo(output) + + # setup flattener + mock_array = MagicMock() + mock_array.__len__.return_value = 1 + mock_array.shape = tuple([1]) + mock_array.reshape.return_value = mock_array + f = reshape.Flattener() + output = f.apply(mock_array) + + # trigger ValueError in undo when reshape fails + mock_array.reshape.side_effect = ValueError + with pytest.raises(ValueError): + f.undo(mock_array) + + # error when input array shape not same rank as shape_attempt + f = reshape.Flattener(shape_attempt=("...", 2)) + output = f.apply(random_array) + with pytest.raises(IndexError): + f.undo(output) + + def test_Flatten(): f1 = reshape.Flatten(flatten_dims=2) random_array = np.random.randn(4, 3, 5) @@ -157,6 +205,20 @@ def test_Flatten_with_shape_attempt_with_ellipses(): assert f.undo_func(undo_data).shape == (2, 1, 1, 1) +def test_Flatten_with_many_arrays(): + incoming_data = (np.zeros((8, 1, 3, 3)), np.zeros((8, 1, 3, 6))) + f = reshape.Flatten() + output = f.apply_func(incoming_data) + assert isinstance(output, tuple) + assert output[0].shape == (8 * 1 * 3 * 3,) + assert output[1].shape == (8 * 1 * 3 * 6,) + # undo + output = f.undo(output) + assert isinstance(output, tuple) + assert output[0].shape == incoming_data[0].shape + assert output[1].shape == incoming_data[1].shape + + def test_SwapAxis(): s = reshape.SwapAxis(1, 3) random_array = np.random.randn(5, 7, 8, 2) @@ -164,3 +226,15 @@ def test_SwapAxis(): assert output.shape == (5, 2, 8, 7), "Swap axes 1 and 3" undo_output = s.undo_func(output) assert np.all(undo_output == random_array), "Undo axis swap." + + +def test_Flattener_prod_shape_helper(): + """Tests the Flattener._prod_shape method with numpy input.""" + f = reshape.Flattener() + data = np.array( + ( + (1, 2, 3), + (4, 5, 6), + ) + ) + assert f._prod_shape(data) == 6 # product of data shape diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_reshape.py b/packages/pipeline/tests/operations/xarray/test_xarray_reshape.py index 48dd8d57..250bc494 100644 --- a/packages/pipeline/tests/operations/xarray/test_xarray_reshape.py +++ b/packages/pipeline/tests/operations/xarray/test_xarray_reshape.py @@ -30,7 +30,7 @@ [1.4, 1.5, 3.3], ], ], - coords=[[10, 20], [0, 1, 2], [5, 6, 7]], + coords=[[10.1, 20], [0, 1, 2], [5, 6, 7]], dims=["height", "lat", "lon"], ) @@ -67,6 +67,13 @@ def test_Dimensions_preserve_order(): assert reversed_output.dims == output.dims +def test_Dimensions_noop_undo(): + """Tests that Dimensions undo returns the input as-is when not applied previously.""" + d = reshape.Dimensions(["lat"], preserve_order=True) + reversed_output = d.undo_func(SIMPLE_DA1) + assert reversed_output.dims == SIMPLE_DA1.dims + + def test_weak_cast_to_int(): wcti = reshape.weak_cast_to_int @@ -81,7 +88,14 @@ def test_CoordinateFlatten(): f = reshape.CoordinateFlatten(["height"]) output = f.apply(SIMPLE_DS2) variables = list(output.keys()) - for vbl in ["Temperature10", "Temperature20", "Humidity10", "Humidity20", "WombatsPerKm210", "WombatsPerKm220"]: + for vbl in [ + "Temperature10.1", + "Temperature20", + "Humidity10.1", + "Humidity20", + "WombatsPerKm210.1", + "WombatsPerKm220", + ]: assert vbl in variables @@ -90,7 +104,7 @@ def test_CoordinateFlatten_complicated_dataset(): f = reshape.CoordinateFlatten(["height"]) output = f.apply(COMPLICATED_DS1) variables = list(output.keys()) - for vbl in ["Temperature10", "Temperature20", "MSLP"]: + for vbl in ["Temperature10.1", "Temperature20", "MSLP"]: assert vbl in variables @@ -120,13 +134,25 @@ def test_CoordinateExpand_reverses_CoordinateFlatten(): variables = list(e_output.keys()) assert "Temperature" in variables + # test noop when non-flatted key is passed to CoordinateExpand + e = reshape.CoordinateExpand("lat") + e_output = e.apply(f_output) + assert list(e_output.keys()) == list(f_output.keys()) + def test_undo_CoordinateExpand(): f = reshape.CoordinateFlatten(["height"]) f_output = f.apply(SIMPLE_DS2) - e = reshape.CoordinateExpand(["height"]) + e = reshape.CoordinateExpand("height") # should be able to accept non-list/tuple arg e_output = e.apply(f_output) e_undone = e.undo(e_output) variables = list(e_undone.keys()) - for vbl in ["Temperature10", "Temperature20", "Humidity10", "Humidity20", "WombatsPerKm210", "WombatsPerKm220"]: + for vbl in [ + "Temperature10.1", + "Temperature20", + "Humidity10.1", + "Humidity20", + "WombatsPerKm210.1", + "WombatsPerKm220", + ]: assert vbl in variables