diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/reshape.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/reshape.py index bcab81a8..f08a09d5 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/reshape.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/reshape.py @@ -142,9 +142,7 @@ def __init__(self, coordinate: Hashable, skip_missing: bool = False): self._skip_missing = skip_missing def apply_func(self, dataset: xr.Dataset) -> xr.Dataset: - discovered_coord = list(set(self._coordinate).intersection(set(dataset.coords))) - - if len(discovered_coord) == 0: + if self._coordinate not in dataset.coords: if self._skip_missing: return dataset @@ -153,7 +151,7 @@ def apply_func(self, dataset: xr.Dataset) -> xr.Dataset: "Set 'skip_missing' to True to skip this." ) - discovered_coord = str(discovered_coord[0]) + discovered_coord = self._coordinate coords = dataset.coords new_ds = xr.Dataset(coords={co: v for co, v in coords.items() if not co == discovered_coord}) @@ -179,7 +177,7 @@ def apply_func(self, dataset: xr.Dataset) -> xr.Dataset: selected = dataset[var].sel(**{discovered_coord: coord_val}) # type: ignore selected = selected.drop_vars(discovered_coord) # type: ignore - selected.attrs.update(**{discovered_coord: coord_val}) + selected.attrs.update(**{str(discovered_coord): coord_val}) new_ds[f"{var}{coord_val}"] = selected return new_ds