Skip to content
Merged
1 change: 0 additions & 1 deletion packages/data/src/pyearthtools/data/patterns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,3 @@
)
from pyearthtools.data.patterns.parser import ParsingPattern
from pyearthtools.data.patterns.static import Static

Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ def _unflatten(data, shape):

if self.shape_attempt:
shape_attempt = self._configure_shape_attempt()
if shape_attempt:
# if self.shape_attempt is truthy then shape_attempt is always truthy.
if shape_attempt: # pragma: no cover
attempts.append((*parsed_shape, *shape_attempt[-1 * self.flatten_dims :])) # type: ignore

for attemp in attempts:
Expand Down Expand Up @@ -315,7 +316,7 @@ def __init__(
"""
super().__init__(
split_tuples=False,
recognised_types=(np.ndarray),
recognised_types=(np.ndarray, tuple),
)
self.record_initialisation()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def apply_func(self, dataset: xr.Dataset) -> xr.Dataset:
coord_size = dataset[var][discovered_coord].values
coord_size = coord_size if isinstance(coord_size, np.ndarray) else np.array(coord_size)

if coord_size.size == 1 and False:
if coord_size.size == 1 and False: # pragma: nocover # TODO: review why this if stmt was put here.
coord_val = weak_cast_to_int(dataset[var][discovered_coord].values)
new_ds[f"{var}{coord_val}"] = Drop(discovered_coord, ignore_missing=True)(dataset[var])

Expand Down Expand Up @@ -234,7 +234,8 @@ def apply_func(self, dataset: xr.Dataset) -> xr.Dataset | xr.DataArray:
dataset = SetType(**{str(coord): dtype})(dataset)

## Add stored encoding if there
if f"{coord}-dtype" in dataset.attrs:
# this is always False since attributes always get overwritten.
if f"{coord}-dtype" in dataset.attrs: # pragma: no cover
dtype = dataset.attrs.pop(f"{coord}-dtype")
dataset[coord].encoding.update(dtype=dtype)

Expand Down
74 changes: 74 additions & 0 deletions packages/pipeline/tests/operations/numpy/test_numpy_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from pyearthtools.pipeline.operations.numpy import reshape
from unittest.mock import MagicMock

import numpy as np
import pytest
Expand Down Expand Up @@ -115,6 +116,53 @@ def test_Flattener_1_dim():
assert np.all(undo_output == random_array), "Undo Flatten 1 dimension."


def success_then_fail(self, *args, **kwargs):
yield self
raise ValueError()


def test_Flattener_exceptions():
"""Tests all the exceptions that can be raised in the Flattener class."""
# try instantiating flattener with invalid dim
with pytest.raises(ValueError):
reshape.Flattener(flatten_dims=0)

# test undo without apply
f = reshape.Flattener(shape_attempt=(2, 1, 1))
random_array = np.random.randn(4, 3, 5)
with pytest.raises(RuntimeError):
f.undo(random_array)

# _configure_shape_attempt error when apply not run
with pytest.raises(RuntimeError):
f._configure_shape_attempt()

# test undo when flatten_dims unset
output = f.apply(random_array)
f.flatten_dims = None # "accidentally" overwrite the dims
with pytest.raises(RuntimeError):
f.undo(output)

# setup flattener
mock_array = MagicMock()
mock_array.__len__.return_value = 1
mock_array.shape = tuple([1])
mock_array.reshape.return_value = mock_array
f = reshape.Flattener()
output = f.apply(mock_array)

# trigger ValueError in undo when reshape fails
mock_array.reshape.side_effect = ValueError
with pytest.raises(ValueError):
f.undo(mock_array)

# error when input array shape not same rank as shape_attempt
f = reshape.Flattener(shape_attempt=("...", 2))
output = f.apply(random_array)
with pytest.raises(IndexError):
f.undo(output)


def test_Flatten():
f1 = reshape.Flatten(flatten_dims=2)
random_array = np.random.randn(4, 3, 5)
Expand Down Expand Up @@ -157,10 +205,36 @@ def test_Flatten_with_shape_attempt_with_ellipses():
assert f.undo_func(undo_data).shape == (2, 1, 1, 1)


def test_Flatten_with_many_arrays():
incoming_data = (np.zeros((8, 1, 3, 3)), np.zeros((8, 1, 3, 6)))
f = reshape.Flatten()
output = f.apply_func(incoming_data)
assert isinstance(output, tuple)
assert output[0].shape == (8 * 1 * 3 * 3,)
assert output[1].shape == (8 * 1 * 3 * 6,)
# undo
output = f.undo(output)
assert isinstance(output, tuple)
assert output[0].shape == incoming_data[0].shape
assert output[1].shape == incoming_data[1].shape


def test_SwapAxis():
s = reshape.SwapAxis(1, 3)
random_array = np.random.randn(5, 7, 8, 2)
output = s.apply_func(random_array)
assert output.shape == (5, 2, 8, 7), "Swap axes 1 and 3"
undo_output = s.undo_func(output)
assert np.all(undo_output == random_array), "Undo axis swap."


def test_Flattener_prod_shape_helper():
"""Tests the Flattener._prod_shape method with numpy input."""
f = reshape.Flattener()
data = np.array(
(
(1, 2, 3),
(4, 5, 6),
)
)
assert f._prod_shape(data) == 6 # product of data shape
36 changes: 31 additions & 5 deletions packages/pipeline/tests/operations/xarray/test_xarray_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
[1.4, 1.5, 3.3],
],
],
coords=[[10, 20], [0, 1, 2], [5, 6, 7]],
coords=[[10.1, 20], [0, 1, 2], [5, 6, 7]],
dims=["height", "lat", "lon"],
)

Expand Down Expand Up @@ -67,6 +67,13 @@ def test_Dimensions_preserve_order():
assert reversed_output.dims == output.dims


def test_Dimensions_noop_undo():
"""Tests that Dimensions undo returns the input as-is when not applied previously."""
d = reshape.Dimensions(["lat"], preserve_order=True)
reversed_output = d.undo_func(SIMPLE_DA1)
assert reversed_output.dims == SIMPLE_DA1.dims


def test_weak_cast_to_int():

wcti = reshape.weak_cast_to_int
Expand All @@ -81,7 +88,14 @@ def test_CoordinateFlatten():
f = reshape.CoordinateFlatten(["height"])
output = f.apply(SIMPLE_DS2)
variables = list(output.keys())
for vbl in ["Temperature10", "Temperature20", "Humidity10", "Humidity20", "WombatsPerKm210", "WombatsPerKm220"]:
for vbl in [
"Temperature10.1",
"Temperature20",
"Humidity10.1",
"Humidity20",
"WombatsPerKm210.1",
"WombatsPerKm220",
]:
assert vbl in variables


Expand All @@ -90,7 +104,7 @@ def test_CoordinateFlatten_complicated_dataset():
f = reshape.CoordinateFlatten(["height"])
output = f.apply(COMPLICATED_DS1)
variables = list(output.keys())
for vbl in ["Temperature10", "Temperature20", "MSLP"]:
for vbl in ["Temperature10.1", "Temperature20", "MSLP"]:
assert vbl in variables


Expand Down Expand Up @@ -120,13 +134,25 @@ def test_CoordinateExpand_reverses_CoordinateFlatten():
variables = list(e_output.keys())
assert "Temperature" in variables

# test noop when non-flatted key is passed to CoordinateExpand
e = reshape.CoordinateExpand("lat")
e_output = e.apply(f_output)
assert list(e_output.keys()) == list(f_output.keys())


def test_undo_CoordinateExpand():
f = reshape.CoordinateFlatten(["height"])
f_output = f.apply(SIMPLE_DS2)
e = reshape.CoordinateExpand(["height"])
e = reshape.CoordinateExpand("height") # should be able to accept non-list/tuple arg
e_output = e.apply(f_output)
e_undone = e.undo(e_output)
variables = list(e_undone.keys())
for vbl in ["Temperature10", "Temperature20", "Humidity10", "Humidity20", "WombatsPerKm210", "WombatsPerKm220"]:
for vbl in [
"Temperature10.1",
"Temperature20",
"Humidity10.1",
"Humidity20",
"WombatsPerKm210.1",
"WombatsPerKm220",
]:
assert vbl in variables