diff --git a/.gitignore b/.gitignore index c67bdd8a..cae40c85 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ atlite.egg-info/ test/*.nc dev-scripts/ dev/ +benchmarks/ examples/*.nc examples/*.csv examples/*.zip diff --git a/atlite/aggregate.py b/atlite/aggregate.py index 2f6b90d2..76ac78d1 100644 --- a/atlite/aggregate.py +++ b/atlite/aggregate.py @@ -1,20 +1,206 @@ # SPDX-FileCopyrightText: Contributors to atlite # # SPDX-License-Identifier: MIT -"""Functions for aggregating results.""" +"""Functions for aggregating and resolving spatial/temporal results.""" from __future__ import annotations -from typing import TYPE_CHECKING, cast +import warnings +from typing import TYPE_CHECKING, Any, Literal, cast -import dask +import geopandas as gpd +import numpy as np +import pandas as pd import xarray as xr +from scipy.sparse import csr_matrix + +from atlite.gis import spdiag if TYPE_CHECKING: - import pandas as pd from scipy.sparse import spmatrix from atlite._types import DataArray + from atlite.cutout import Cutout + + +def ensure_index_name(index: pd.Index) -> pd.Index: + """Return *index* with name ``"dim_0"`` when it has no name.""" + if index.name is None: + return index.rename("dim_0") + return index + + +def resolve_matrix( + cutout: Cutout, + matrix: Any, + index: Any, + shapes: Any, + shapes_crs: int, + layout: Any, +) -> tuple[csr_matrix | None, pd.Index | None]: + """Resolve *matrix*, *shapes* and *layout* into a sparse matrix and index. + + Validates the inputs, builds an indicator matrix from *shapes* when + needed, and folds *layout* capacities into the matrix. Returns + ``(None, None)`` when no spatial aggregation is requested. + """ + if matrix is not None: + if shapes is not None: + raise ValueError( + "Passing matrix and shapes is ambiguous. Pass only one of them." 
+ ) + + if isinstance(matrix, xr.DataArray): + coords = matrix.indexes.get(matrix.dims[1]).to_frame(index=False) + if not np.array_equal(coords[["x", "y"]], cutout.grid[["x", "y"]]): + raise ValueError( + "Matrix spatial coordinates not aligned with cutout spatial " + "coordinates." + ) + if index is None: + index = matrix + + if not matrix.ndim == 2: + raise ValueError("Matrix not 2-dimensional.") + + matrix = csr_matrix(matrix) + + if shapes is not None: + geoseries_like = (pd.Series, gpd.GeoDataFrame, gpd.GeoSeries) + if isinstance(shapes, geoseries_like) and index is None: + index = shapes.index + matrix = cutout.indicatormatrix(shapes, shapes_crs) + + if layout is not None: + assert isinstance(layout, xr.DataArray) + layout = layout.reindex_like(cutout.data).stack(spatial=["y", "x"]) + if matrix is None: + matrix = csr_matrix(layout.expand_dims("new")) + else: + matrix = csr_matrix(matrix) * spdiag(layout) + + if matrix is not None and index is None: + index = pd.RangeIndex(matrix.shape[0]) + + return matrix, index + + +def normalize_aggregate_time( + aggregate_time: Literal["sum", "mean", "legacy"] | None, + no_spatial: bool, + capacity_factor: bool = False, + capacity_factor_timeseries: bool = False, +) -> Literal["sum", "mean"] | None: + """Normalise the *aggregate_time* parameter to ``"sum"``, ``"mean"`` or ``None``. + + Handles the deprecated ``"legacy"`` value and the deprecated + *capacity_factor* / *capacity_factor_timeseries* flags, emitting + :class:`FutureWarning` where appropriate. + """ + if aggregate_time not in ("sum", "mean", "legacy", None): + raise ValueError( + f"aggregate_time must be 'sum', 'mean', 'legacy', or None, " + f"got {aggregate_time!r}" + ) + + if aggregate_time == "legacy": + warnings.warn( + "aggregate_time='legacy' is deprecated and will be removed in a " + "future release. 
Pass 'sum', 'mean', or None explicitly.", + FutureWarning, + stacklevel=3, + ) + + if capacity_factor or capacity_factor_timeseries: + if aggregate_time != "legacy": + raise ValueError( + "Cannot use 'aggregate_time' together with deprecated " + "'capacity_factor' or 'capacity_factor_timeseries'." + ) + if capacity_factor: + warnings.warn( + "capacity_factor is deprecated. Use aggregate_time='mean' instead.", + FutureWarning, + stacklevel=3, + ) + aggregate_time = "mean" + if capacity_factor_timeseries: + warnings.warn( + "capacity_factor_timeseries is deprecated. " + "Use aggregate_time=None instead.", + FutureWarning, + stacklevel=3, + ) + aggregate_time = None + + if aggregate_time == "legacy": + return "sum" if no_spatial else None + return aggregate_time + + +def reduce_time( + da: xr.DataArray, method: Literal["sum", "mean"] | None +) -> xr.DataArray: + """Reduce *da* along the ``time`` dimension using *method*. + + Returns *da* unchanged when *method* is ``None``. + """ + if method == "sum": + return da.sum("time", keep_attrs=True) + if method == "mean": + return da.mean("time", keep_attrs=True) + return da + + +def build_capacity_array(matrix: Any, index: pd.Index) -> xr.DataArray: + """Sum *matrix* columns to obtain the installed capacity per bus.""" + capacity = xr.DataArray(np.asarray(matrix.sum(-1)).flatten(), [index]) + capacity.attrs["units"] = "MW" + return capacity + + +def wrap_matrix_result( + data: np.ndarray, + time: xr.DataArray, + index: pd.Index, +) -> DataArray: + """Wrap a ``(time, n_regions)`` numpy array into a labelled DataArray.""" + index = ensure_index_name(index) + return xr.DataArray( + data, + dims=("time", index.name), + coords={"time": time, index.name: index}, + ) + + +def finalize_aggregated_result( + result: xr.DataArray, + matrix: Any, + index: pd.Index, + per_unit: bool, + return_capacity: bool, + aggregate_time_method: Literal["sum", "mean"] | None, +) -> DataArray | tuple[DataArray, DataArray]: + """Apply per-unit 
normalisation, time aggregation and capacity extraction. + + Returns either the finalised DataArray or a ``(result, capacity)`` tuple + when *return_capacity* is ``True``. + """ + capacity = None + if per_unit or return_capacity: + capacity = build_capacity_array(matrix, index) + + if per_unit: + result = (result / capacity.where(capacity != 0)).fillna(0.0) + result.attrs["units"] = "p.u." + else: + result.attrs["units"] = "MW" + + result = reduce_time(result, aggregate_time_method) + + if return_capacity: + return result, capacity + return result def aggregate_matrix( @@ -22,26 +208,26 @@ def aggregate_matrix( matrix: spmatrix, index: pd.Index, ) -> DataArray: - """ - Aggregate spatial data with a sparse matrix. + """Aggregate spatial data with a sparse *matrix*. Parameters ---------- da : xarray.DataArray DataArray with spatial dimensions ``y`` and ``x``. matrix : scipy.sparse.spmatrix - Aggregation matrix mapping flattened spatial cells to ``index``. + Aggregation matrix mapping flattened spatial cells to *index*. index : pandas.Index Index defining the aggregated dimension. Returns ------- xarray.DataArray - Aggregated data indexed by ``index`` and, if present, time. + Aggregated data indexed by *index* and, if present, time. 
""" - if index.name is None: - index = index.rename("dim_0") - if isinstance(da.data, dask.array.core.Array): + import dask as _dask + + index = ensure_index_name(index) + if isinstance(da.data, _dask.array.core.Array): da = da.stack(spatial=("y", "x")) da = da.chunk({"spatial": -1}) result = xr.apply_ufunc( diff --git a/atlite/convert.py b/atlite/convert.py index 14dd99cd..d107f794 100644 --- a/atlite/convert.py +++ b/atlite/convert.py @@ -7,7 +7,6 @@ import datetime as dt import logging -import warnings from collections import namedtuple from operator import itemgetter from pathlib import Path @@ -21,13 +20,17 @@ from dask.array import absolute, arccos, cos, maximum, mod, radians, sin, sqrt from dask.diagnostics import ProgressBar from numpy import pi -from scipy.sparse import csr_matrix from atlite import csp as cspm from atlite import hydro as hydrom from atlite import wind as windm -from atlite.aggregate import aggregate_matrix -from atlite.gis import spdiag +from atlite.aggregate import ( + aggregate_matrix, + finalize_aggregated_result, + normalize_aggregate_time, + reduce_time, + resolve_matrix, +) from atlite.pv.irradiation import TiltedIrradiation from atlite.pv.orientation import SurfaceOrientation, get_orientation from atlite.pv.solar_panel_model import SolarPanelModel @@ -39,8 +42,6 @@ windturbine_smooth, ) -logger = logging.getLogger(__name__) - if TYPE_CHECKING: from collections.abc import Callable @@ -59,14 +60,7 @@ from atlite.resource import CSPConfig, PanelConfig, TurbineConfig -def _aggregate_time( - da: xr.DataArray, method: Literal["sum", "mean"] | None -) -> xr.DataArray: - if method == "sum": - return da.sum("time", keep_attrs=True) - if method == "mean": - return da.mean("time", keep_attrs=True) - return da +logger = logging.getLogger(__name__) def convert_and_aggregate( @@ -84,6 +78,7 @@ def convert_and_aggregate( capacity_factor_timeseries: bool = False, show_progress: bool = False, dask_kwargs: dict[str, Any] | None = None, + backend: 
Literal["auto", "dask", "streaming"] = "auto", **convert_kwds: Any, ) -> Any: """ @@ -134,7 +129,13 @@ def convert_and_aggregate( show_progress : boolean, default False Whether to show a progress bar. dask_kwargs : dict, default {} - Dict with keyword arguments passed to ``dask.compute``. + Dict with keyword arguments passed to ``load`` when using + ``backend="dask"``. + backend : "auto", "dask", or "streaming", default "auto" + Execution backend. ``"auto"`` prefers streaming when available and + otherwise uses dask-backed xarray loading. ``"dask"`` always uses the + dask-backed path. ``"streaming"`` requires a file-backed cutout that + supports streaming. **convert_kwds : Any Additional keyword arguments passed to ``convert_func``. @@ -176,121 +177,78 @@ def convert_and_aggregate( pv : Generate solar PV generation time-series. """ - if aggregate_time not in ("sum", "mean", "legacy", None): + if backend not in ("auto", "dask", "streaming"): raise ValueError( - f"aggregate_time must be 'sum', 'mean', 'legacy', or None, " - f"got {aggregate_time!r}" + f"backend must be 'auto', 'dask', or 'streaming', got {backend!r}" ) - if aggregate_time == "legacy": - warnings.warn( - "aggregate_time='legacy' is deprecated and will be removed in a " - "future release. Pass 'sum', 'mean', or None explicitly.", - FutureWarning, - stacklevel=2, - ) + no_args = all(v is None for v in [layout, shapes, matrix]) + agg = normalize_aggregate_time( + aggregate_time, no_args, capacity_factor, capacity_factor_timeseries + ) - if capacity_factor or capacity_factor_timeseries: - if aggregate_time != "legacy": - raise ValueError( - "Cannot use 'aggregate_time' together with deprecated " - "'capacity_factor' or 'capacity_factor_timeseries'." - ) - if capacity_factor: - warnings.warn( - "capacity_factor is deprecated. 
Use aggregate_time='mean' instead.", - FutureWarning, - stacklevel=2, - ) - aggregate_time = "mean" - if capacity_factor_timeseries: - warnings.warn( - "capacity_factor_timeseries is deprecated. " - "Use aggregate_time=None instead.", - FutureWarning, - stacklevel=2, - ) - aggregate_time = None + if no_args and (per_unit or return_capacity): + raise ValueError( + "One of `matrix`, `shapes` and `layout` must be " + "given for `per_unit` or `return_capacity`" + ) func_name = convert_func.__name__.replace("convert_", "") logger.info("Convert and aggregate '%s'.", func_name) - da = convert_func(cutout.data, **convert_kwds) dask_kwargs = dask_kwargs or {} - no_args = all(v is None for v in [layout, shapes, matrix]) - - if no_args: - if per_unit or return_capacity: - raise ValueError( - "One of `matrix`, `shapes` and `layout` must be " - "given for `per_unit` or `return_capacity`" + if dask_kwargs and backend != "dask": + raise ValueError("dask_kwargs require backend='dask'.") + + matrix, index = resolve_matrix(cutout, matrix, index, shapes, shapes_crs, layout) + + if backend != "dask": + from atlite.streaming import can_stream, stream_conversion + + stream_supported = can_stream(cutout) + if backend == "streaming" and not stream_supported: + raise ValueError("backend='streaming' requires a streamable cutout.") + + if stream_supported: + result = stream_conversion( + cutout, + convert_func, + matrix=matrix, + index=index, + per_unit=per_unit, + return_capacity=return_capacity, + aggregate_time=agg, + show_progress=show_progress, + convert_kwds=convert_kwds, ) - - agg = "sum" if aggregate_time == "legacy" else aggregate_time - res = _aggregate_time(da, agg) - return maybe_progressbar(res, show_progress, **dask_kwargs) - - if matrix is not None: - if shapes is not None: - raise ValueError( - "Passing matrix and shapes is ambiguous. Pass only one of them." 
- ) - - if isinstance(matrix, xr.DataArray): - coords = matrix.indexes.get(matrix.dims[1]).to_frame(index=False) - if not np.array_equal(coords[["x", "y"]], cutout.grid[["x", "y"]]): + if result is not None: + return result + if backend == "streaming": raise ValueError( - "Matrix spatial coordinates not aligned with cutout spatial " - "coordinates." + f"backend='streaming' is not supported for {func_name!r}." ) + logger.debug("Streaming fallback to dask for '%s'.", func_name) - if index is None: - index = matrix - - if not matrix.ndim == 2: - raise ValueError("Matrix not 2-dimensional.") - - matrix = csr_matrix(matrix) - - if shapes is not None: - geoseries_like = (pd.Series, gpd.GeoDataFrame, gpd.GeoSeries) - if isinstance(shapes, geoseries_like) and index is None: - index = shapes.index - - matrix = cutout.indicatormatrix(shapes, shapes_crs) - - if layout is not None: - assert isinstance(layout, xr.DataArray) - layout = layout.reindex_like(cutout.data).stack(spatial=["y", "x"]) - - if matrix is None: - matrix = csr_matrix(layout.expand_dims("new")) - else: - matrix = csr_matrix(matrix) * spdiag(layout) + da = convert_func(cutout.data, **convert_kwds) - # From here on, matrix is defined and ensured to be a csr matrix. - if index is None: - index = pd.RangeIndex(matrix.shape[0]) + if no_args: + res = reduce_time(da, agg) + return maybe_progressbar(res, show_progress, **dask_kwargs) results = aggregate_matrix(da, matrix=matrix, index=index) - - if per_unit or return_capacity: - caps = matrix.sum(-1) - capacity = xr.DataArray(np.asarray(caps).flatten(), [index]) - capacity.attrs["units"] = "MW" - - if per_unit: - results = (results / capacity.where(capacity != 0)).fillna(0.0) - results.attrs["units"] = "p.u." 
- else: - results.attrs["units"] = "MW" - - if aggregate_time != "legacy": - results = _aggregate_time(results, aggregate_time) + finalized = finalize_aggregated_result( + results, + matrix, + index, + per_unit, + return_capacity, + agg, + ) if return_capacity: - return maybe_progressbar(results, show_progress, **dask_kwargs), capacity - return maybe_progressbar(results, show_progress, **dask_kwargs) + result, capacity = finalized + return maybe_progressbar(result, show_progress, **dask_kwargs), capacity + return maybe_progressbar(finalized, show_progress, **dask_kwargs) def maybe_progressbar( diff --git a/atlite/cutout.py b/atlite/cutout.py index b48342ff..b5c7f4b2 100644 --- a/atlite/cutout.py +++ b/atlite/cutout.py @@ -75,6 +75,22 @@ logger = logging.getLogger(__name__) +def _storage_aligned_chunks(ds: xr.Dataset) -> dict[str, int] | None: + for var in ds.data_vars.values(): + chunksizes = var.encoding.get("chunksizes") + if not chunksizes: + continue + dims = var.dims[: len(chunksizes)] + chunks = { + dim: size + for dim, size in zip(dims, chunksizes, strict=False) + if dim == "time" + } + if chunks: + return chunks + return None + + class Cutout: """ Cutout base class. @@ -166,17 +182,22 @@ def __init__(self, path: PathLike, **cutoutparams: Any) -> None: """ path = Path(path).with_suffix(".nc") - chunks = cutoutparams.pop("chunks", {"time": 100}) - if isinstance(chunks, dict): - storable_chunks = {f"chunksize_{k}": v for k, v in (chunks or {}).items()} - else: - storable_chunks = {} + default_chunks = object() + chunks = cutoutparams.pop("chunks", default_chunks) # Three cases. First, cutout exists -> take the data. # Second, data is given -> take it. 
Third, else -> build a new cutout if path.is_file(): data = xr.open_dataset(str(path)) - data = data.chunk(chunks) + if chunks is default_chunks: + chunks = _storage_aligned_chunks(data) or {"time": 100} + if chunks is not None: + data = data.chunk(chunks) + storable_chunks = ( + {f"chunksize_{k}": v for k, v in (chunks or {}).items()} + if isinstance(chunks, dict) + else {} + ) data.attrs.update(storable_chunks) if cutoutparams: warn( @@ -185,8 +206,22 @@ def __init__(self, path: PathLike, **cutoutparams: Any) -> None: stacklevel=2, ) elif "data" in cutoutparams: + if chunks is default_chunks: + chunks = {"time": 100} + storable_chunks = ( + {f"chunksize_{k}": v for k, v in (chunks or {}).items()} + if isinstance(chunks, dict) + else {} + ) data = cutoutparams.pop("data") else: + if chunks is default_chunks: + chunks = {"time": 100} + storable_chunks = ( + {f"chunksize_{k}": v for k, v in (chunks or {}).items()} + if isinstance(chunks, dict) + else {} + ) logger.info("Building new cutout %s", path) if "bounds" in cutoutparams: diff --git a/atlite/gis.py b/atlite/gis.py index b22b7c36..c4b8ed7c 100644 --- a/atlite/gis.py +++ b/atlite/gis.py @@ -6,11 +6,11 @@ from __future__ import annotations import logging -import multiprocessing as mp +import threading from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import TYPE_CHECKING, Any, cast -from warnings import catch_warnings, simplefilter import geopandas as gpd import numpy as np @@ -27,6 +27,7 @@ from rasterio.plot import show from rasterio.warp import transform_bounds from scipy.ndimage import binary_dilation as dilation +from scipy.ndimage import distance_transform_cdt, generate_binary_structure from shapely.ops import transform from shapely.strtree import STRtree from tqdm import tqdm @@ -451,25 +452,11 @@ def shape_availability( masked, transform = projected_mask( d["raster"], geometry, transform, shape, excluder.crs, **kwargs ) - if 
d["codes"]: - if callable(d["codes"]): - masked_ = d["codes"](masked).astype(bool) - else: - masked_ = isin(masked, d["codes"]) - else: - masked_ = masked.astype(bool) - - if d["invert"]: - masked_ = ~masked_ - if d["buffer"]: - iterations = int(d["buffer"] / excluder.res) + 1 - masked_ = dilation(masked_, iterations=iterations) - - exclusions = exclusions | masked_ + exclusions |= apply_exclusion_entry(d, masked, excluder.res) for d in excluder.geometries: masked = ~geometry_mask(d["geometry"], shape, transform, invert=d["invert"]) - exclusions = exclusions | masked + exclusions |= masked return ~exclusions, transform @@ -636,12 +623,13 @@ def open_files(self) -> None: else: assert isinstance(raster, rio.DatasetReader) - # Check if the raster has a valid CRS if not raster.crs: if d["crs"]: raster._crs = CRS(d["crs"]) else: raise ValueError(f"CRS of {raster} is invalid, please provide it.") + elif d["crs"]: + raster._crs = CRS(d["crs"]) d["raster"] = raster for d in self.geometries: @@ -830,36 +818,6 @@ def plot_shape_availability( return ax -_mp_shapes: GeoSeries -_mp_excluder: ExclusionContainer -_mp_dst_transform: rio.Affine -_mp_dst_crs: CrsLike -_mp_dst_shapes: tuple[int, int] - - -def _init_process( - shapes_: GeoSeries, - excluder_: ExclusionContainer, - dst_transform_: rio.Affine, - dst_crs_: CrsLike, - dst_shapes_: tuple[int, int], -) -> None: - global _mp_shapes, _mp_excluder, _mp_dst_transform, _mp_dst_crs, _mp_dst_shapes - _mp_shapes, _mp_excluder = shapes_, excluder_ - _mp_dst_transform, _mp_dst_crs, _mp_dst_shapes = ( - dst_transform_, - dst_crs_, - dst_shapes_, - ) - - -def _process_func(i: Any) -> NDArray: - args = (_mp_excluder, _mp_dst_transform, _mp_dst_crs, _mp_dst_shapes) - with catch_warnings(): - simplefilter("ignore") - return shape_availability_reprojected(_mp_shapes.loc[[i]], *args)[0] - - def compute_availabilitymatrix( cutout: Any, shapes: GeoDataFrame | GeoSeries, @@ -870,9 +828,9 @@ def compute_availabilitymatrix( """ Compute the 
eligible share within cutout cells in the overlap with shapes. - For parallel calculation (nprocesses not None) the excluder must not be - initialized and all raster references must be strings. Otherwise processes - are colliding when reading from one common rasterio.DatasetReader. + When ``nprocesses`` is set, raster data is pre-read into memory and + per-shape processing runs in parallel using threads. Each thread gets + its own rasterio file handles, so there is no file-handle sharing issue. Parameters ---------- @@ -884,12 +842,10 @@ def compute_availabilitymatrix( Container of all meta data or objects which to exclude, i.e. rasters and geometries. nprocesses : int, optional - Number of processes to use for calculating the matrix. The paralle- - lization can heavily boost the calculation speed. The default is None. + Number of threads for parallel calculation. The default is None + (serial). disable_progressbar: bool, optional - Disable the progressbar if nprocesses is not None. Then the `map` - function instead of the `imap` function is used for the multiprocessing - pool. This speeds up the calculation. + Disable the progressbar. The default is True. 
Returns ------- @@ -911,7 +867,10 @@ def compute_availabilitymatrix( shapes = shapes.geometry if isinstance(shapes, gpd.GeoDataFrame) else shapes shapes = shapes.to_crs(excluder.crs) - args = (excluder, cutout.transform_r, cutout.crs, cutout.shape) + dst_transform = cutout.transform_r + dst_crs = cutout.crs + dst_shape = cutout.shape + tqdm_kwargs = { "ascii": False, "unit": " gridcells", @@ -919,33 +878,39 @@ def compute_availabilitymatrix( "desc": "Compute availability matrix", } + if not excluder.all_open: + excluder.open_files() + + cache = RasterCache(excluder) + if nprocesses is None: if not disable_progressbar: iterator = tqdm(shapes.index, **tqdm_kwargs) else: iterator = shapes.index - with catch_warnings(): - simplefilter("ignore") - availability = [] - for i in iterator: - _ = shape_availability_reprojected(shapes.loc[[i]], *args)[0] - availability.append(_) + availability = [ + reproject_cached_availability( + shapes.loc[[i]], + excluder, + cache, + dst_transform, + dst_crs, + dst_shape, + ) + for i in iterator + ] else: - assert excluder.all_closed, ( - "For parallelization all raster files in excluder must be closed" + availability = compute_availability_threaded( + shapes, + excluder, + cache, + dst_transform, + dst_crs, + dst_shape, + nprocesses, + disable_progressbar, + tqdm_kwargs, ) - with mp.get_context("spawn").Pool( - processes=nprocesses, - initializer=_init_process, - initargs=(shapes, *args), - maxtasksperchild=20, - ) as pool: - if disable_progressbar: - availability = list(pool.map(_process_func, shapes.index)) - else: - availability = list( - tqdm(pool.imap(_process_func, shapes.index), **tqdm_kwargs) - ) availability_arr = np.stack(availability)[:, ::-1] # flip axis, see Notes if availability_arr.ndim == 4: @@ -954,6 +919,463 @@ def compute_availabilitymatrix( return xr.DataArray(availability_arr, coords=coords) +def reproject_cached_availability( + geometry: GeoSeries, + excluder: ExclusionContainer, + cache: RasterCache, + 
dst_transform: rio.Affine, + dst_crs: CrsLike, + dst_shape: tuple[int, int], +) -> NDArray: + """ + Compute cached availability for a geometry and reproject to a cutout grid. + + Combines :func:`shape_availability_cached`, :func:`pad_extent`, and + ``rasterio.warp.reproject`` into a single call. + + Parameters + ---------- + geometry : geopandas.GeoSeries + Geometry for which availability is computed. + excluder : ExclusionContainer + Exclusion container with rasters and geometries to exclude. + cache : RasterCache + Pre-loaded raster cache built from *excluder*. + dst_transform : rasterio.Affine + Target grid affine transform. + dst_crs : CRS-like + Target coordinate reference system. + dst_shape : tuple of int + Target grid shape ``(rows, cols)``. + + Returns + ------- + numpy.ndarray + Reprojected availability matrix with values in ``[0, 1]``. + + """ + avail, trans = shape_availability_cached(geometry, excluder, cache) + padded, pt = pad_extent(avail, trans, dst_transform, excluder.crs, dst_crs) + return rio.warp.reproject( + padded.astype(np.uint8), + empty(dst_shape), + resampling=rio.warp.Resampling.average, + src_transform=pt, + dst_transform=dst_transform, + src_crs=excluder.crs, + dst_crs=dst_crs, + )[0] + + +def compute_availability_threaded( + shapes: GeoSeries, + excluder: ExclusionContainer, + cache: RasterCache, + dst_transform: rio.Affine, + dst_crs: CrsLike, + dst_shape: tuple[int, int], + nprocesses: int, + disable_progressbar: bool, + tqdm_kwargs: dict[str, Any], +) -> list[NDArray]: + """ + Process shapes in parallel threads with per-thread file handles. + + Each worker thread gets its own :class:`ExclusionContainer` with + independent rasterio file handles while sharing the read-only *cache*. + All thread-local file handles are closed when the pool shuts down. + + Parameters + ---------- + shapes : geopandas.GeoSeries + Geometries to process, indexed by shape identifier. 
+ excluder : ExclusionContainer + Template exclusion container (replicated per thread). + cache : RasterCache + Pre-loaded raster cache shared across threads. + dst_transform : rasterio.Affine + Target grid affine transform. + dst_crs : CRS-like + Target coordinate reference system. + dst_shape : tuple of int + Target grid shape ``(rows, cols)``. + nprocesses : int + Number of worker threads. + disable_progressbar : bool + Whether to suppress the progress bar. + tqdm_kwargs : dict + Extra keyword arguments forwarded to :func:`tqdm.tqdm`. + + Returns + ------- + list of numpy.ndarray + Reprojected availability arrays, one per shape. + + """ + tls = threading.local() + raster_paths = [ + ( + d["raster"].name + if isinstance(d["raster"], rio.DatasetReader) + else d["raster"] + ) + for d in excluder.rasters + ] + geometry_data = [ + ( + d["geometry"].copy(), + d.get("buffer", 0), + d.get("invert", False), + d.get("_buffered", False), + ) + for d in excluder.geometries + ] + raster_entries = [ + {k: v for k, v in d.items() if k not in ("raster", "geometry")} + for d in excluder.rasters + ] + + thread_excluders: list[ExclusionContainer] = [] + lock = threading.Lock() + + def _get_thread_excluder() -> ExclusionContainer: + if getattr(tls, "excluder", None) is None: + exc = ExclusionContainer(crs=excluder.crs, res=excluder.res) + for path, entry in zip(raster_paths, raster_entries, strict=True): + exc.add_raster(path, **entry) + for geom_data, buffer, invert, buffered in geometry_data: + exc.add_geometry(geom_data, buffer=buffer, invert=invert) + if buffered: + exc.geometries[-1]["_buffered"] = True + exc.open_files() + tls.excluder = exc + with lock: + thread_excluders.append(exc) + return tls.excluder + + def _process(i: Any) -> NDArray: + thread_excluder = _get_thread_excluder() + return reproject_cached_availability( + shapes.loc[[i]], + thread_excluder, + cache, + dst_transform, + dst_crs, + dst_shape, + ) + + try: + with ThreadPoolExecutor(max_workers=nprocesses) as 
pool: + if disable_progressbar: + return list(pool.map(_process, shapes.index)) + return list(tqdm(pool.map(_process, shapes.index), **tqdm_kwargs)) + finally: + for exc in thread_excluders: + for d in exc.rasters: + r = d["raster"] + if isinstance(r, rio.DatasetReader) and not r.closed: + r.close() + + +def fast_isin(arr: NDArray, codes: Sequence[int]) -> NDArray: + """ + Test element membership using a lookup table for small-integer arrays. + + For ``uint8`` arrays or integer arrays with max value below 65 536, + builds a boolean LUT for O(1) per-element lookup. Falls back to + :func:`numpy.isin` for other dtypes. + + Parameters + ---------- + arr : numpy.ndarray + Input array to test. + codes : sequence of int + Values to test for membership. + + Returns + ------- + numpy.ndarray + Boolean array of the same shape as *arr*. + + """ + if arr.dtype == np.uint8 or (arr.dtype.kind in "iu" and arr.max() < 65536): + lut = np.zeros(max(int(arr.max()) + 1, max(codes) + 1), dtype=bool) + lut[list(codes)] = True + return lut[arr] + return isin(arr, codes) + + +def fast_dilation(mask: NDArray, iterations: int) -> NDArray: + """ + Binary dilation using distance transform for large iteration counts. + + For ``iterations > 3``, uses :func:`scipy.ndimage.distance_transform_cdt` + with cityblock metric, which is equivalent to iterative cross-shaped + dilation but significantly faster. + + Parameters + ---------- + mask : numpy.ndarray + Boolean 2-D mask to dilate. + iterations : int + Number of dilation iterations. If 0, *mask* is returned unchanged. + + Returns + ------- + numpy.ndarray + Dilated boolean mask. 
+ + """ + if iterations <= 0: + return mask + struct = generate_binary_structure(2, 1) + if iterations > 3: + dist = distance_transform_cdt(~mask, metric="cityblock") + return dist <= iterations + return dilation(mask, structure=struct, iterations=iterations) + + +def apply_exclusion_entry( + d: dict[str, Any], + masked: NDArray, + res: float, +) -> NDArray: + """ + Apply codes filter, inversion, and buffer dilation to a raster mask. + + Processes a single exclusion-container raster entry: filters by + land-use codes (or a callable), optionally inverts, and dilates by + the buffer distance. + + Parameters + ---------- + d : dict + Raster entry dict with keys ``"codes"``, ``"invert"``, ``"buffer"``. + masked : numpy.ndarray + Raster data array to filter. + res : float + Spatial resolution in the exclusion CRS, used to convert buffer + distance to dilation iterations. + + Returns + ------- + numpy.ndarray + Boolean exclusion mask (``True`` = excluded). + + """ + if d["codes"]: + if callable(d["codes"]): + masked_ = d["codes"](masked).astype(bool) + else: + masked_ = fast_isin(masked, d["codes"]) + else: + masked_ = masked.astype(bool) + + if d["invert"]: + masked_ = ~masked_ + if d["buffer"]: + iterations = int(d["buffer"] / res) + 1 + masked_ = fast_dilation(masked_, iterations) + return masked_ + + +class RasterCache: + """ + In-memory cache of raster data read from an ExclusionContainer. + + Reads each unique raster file once via ``raster.read(1)`` and stores + the full array, its affine transform, and CRS. Subsequent per-shape + reads are served by numpy slicing, avoiding repeated disk I/O. + + Parameters + ---------- + excluder : ExclusionContainer + Container whose raster files are pre-loaded. Files are opened + automatically if not already open. 
+ + """ + + def __init__(self, excluder: ExclusionContainer) -> None: + if not excluder.all_open: + excluder.open_files() + + self._data: dict[str, tuple[NDArray, rio.Affine, CrsLike]] = {} + for d in excluder.rasters: + raster = d["raster"] + key = raster.name + if key in self._data: + continue + data = raster.read(1) + self._data[key] = (data, raster.transform, raster.crs) + + def window_read( + self, + raster: rio.DatasetReader, + geom: GeoSeries, + transform: rio.Affine | None, + shape: tuple[int, int] | None, + crs: CrsLike, + nodata: int = 255, + allow_no_overlap: bool = False, + ) -> tuple[NDArray, rio.Affine]: + """ + Read a geometry-bounded window from cached raster data. + + Computes the pixel window covering *geom*, slices it from the + in-memory array, and reprojects to the target grid when the + native transform/shape differ from *transform*/*shape*. + + Parameters + ---------- + raster : rasterio.DatasetReader + Open raster whose ``name`` is the cache lookup key. + geom : geopandas.GeoSeries + Geometry defining the spatial extent to read. + transform : rasterio.Affine or None + Target affine transform. ``None`` returns the native window + transform. + shape : tuple of int or None + Target array shape ``(rows, cols)``. ``None`` uses the + native window shape. + crs : CRS-like + Target CRS for reprojection. + nodata : int + Fill value for areas outside the raster extent. + allow_no_overlap : bool + Return a nodata-filled array when the geometry does not + overlap the raster instead of raising. + + Returns + ------- + data : numpy.ndarray + Raster values for the requested window. + transform : rasterio.Affine + Affine transform of the returned array. + + Raises + ------ + ValueError + If the geometry does not overlap the raster and + *allow_no_overlap* is ``False``. 
+ + """ + key = raster.name + data, src_transform, src_crs = self._data[key] + + if geom.crs != src_crs: + geom = geom.to_crs(src_crs) + + bounds_arr = geom.total_bounds + res_x, res_y = abs(src_transform.a), abs(src_transform.e) + col_off = int((bounds_arr[0] - src_transform.c) / src_transform.a) + row_off = int((src_transform.f - bounds_arr[3]) / res_y) + col_end = int(np.ceil((bounds_arr[2] - src_transform.c) / src_transform.a)) + row_end = int(np.ceil((src_transform.f - bounds_arr[1]) / res_y)) + + h, w = data.shape + if col_off >= w or row_off >= h or col_end <= 0 or row_end <= 0: + if allow_no_overlap: + if transform is not None and shape is not None: + return np.full(shape, nodata, dtype=data.dtype), transform + fallback_t, fallback_s = padded_transform_and_shape(bounds_arr, res_x) + return np.full(fallback_s, nodata, dtype=data.dtype), fallback_t + raise ValueError("Input shapes do not overlap raster.") + + col_off = max(0, col_off) + row_off = max(0, row_off) + col_end = min(w, col_end) + row_end = min(h, row_end) + + window_data = data[row_off:row_end, col_off:col_end].copy() + window_transform = rio.Affine( + src_transform.a, + 0, + src_transform.c + col_off * src_transform.a, + 0, + src_transform.e, + src_transform.f + row_off * src_transform.e, + ) + + outside = geometry_mask(geom, window_data.shape, window_transform, invert=False) + window_data[outside] = nodata + + if transform is None or ( + window_transform == transform and window_data.shape == shape + ): + return window_data, window_transform + + assert shape is not None and crs is not None + dtype = data.dtype if data.dtype.kind in "iu" else np.float64 + dst = np.empty(shape, dtype=dtype) + gdal_logger = logging.getLogger("rasterio._err") + prev_level = gdal_logger.level + gdal_logger.setLevel(logging.ERROR) + try: + return rio.warp.reproject( + window_data, + dst, + src_crs=src_crs, + dst_crs=crs, + src_transform=window_transform, + dst_transform=transform, + src_nodata=nodata, + 
dst_nodata=nodata, + ) + finally: + gdal_logger.setLevel(prev_level) + + +def shape_availability_cached( + geometry: GeoSeries, + excluder: ExclusionContainer, + cache: RasterCache, +) -> tuple[NDArray, rio.Affine]: + """ + Compute eligible area using pre-loaded raster data. + + Equivalent to :func:`shape_availability` but reads raster data from + *cache* instead of disk, avoiding per-shape I/O overhead. + + Parameters + ---------- + geometry : geopandas.GeoSeries + Geometry of which the eligible area is computed. + excluder : ExclusionContainer + Container of exclusion rasters and geometries. + cache : RasterCache + Pre-loaded raster cache built from *excluder*. + + Returns + ------- + masked : numpy.ndarray + Boolean mask where ``True`` indicates eligible cells. + transform : rasterio.Affine + Affine transform of the mask. + + """ + bounds = rio.features.bounds(geometry) + transform, shape = padded_transform_and_shape(bounds, res=excluder.res) + exclusions = geometry_mask(geometry, shape, transform) + + raster_name: str | None = None + for d in excluder.rasters: + name = d["raster"].name + if name != raster_name: + raster_name = name + kwargs_keys = ["allow_no_overlap", "nodata"] + kwargs = {k: v for k, v in d.items() if k in kwargs_keys} + masked, transform = cache.window_read( + d["raster"], geometry, transform, shape, excluder.crs, **kwargs + ) + exclusions |= apply_exclusion_entry(d, masked, excluder.res) + + for d in excluder.geometries: + masked = ~geometry_mask(d["geometry"], shape, transform, invert=d["invert"]) + exclusions |= masked + + return ~exclusions, transform + + def maybe_swap_spatial_dims( ds: Dataset | DataArray, namex: str = "x", namey: str = "y" ) -> Dataset | DataArray: diff --git a/atlite/streaming.py b/atlite/streaming.py new file mode 100644 index 00000000..0f9486c4 --- /dev/null +++ b/atlite/streaming.py @@ -0,0 +1,309 @@ +# SPDX-FileCopyrightText: Contributors to atlite +# +# SPDX-License-Identifier: MIT + +"""Streaming conversion 
backend with chunk-aligned I/O.

Processes weather-to-energy conversions one time-chunk at a time so that
the full ``(time, y, x)`` grid is never materialised in memory. Only
cases that actually benefit from streaming (matrix aggregation or
temporal reduction) are handled; all other cases fall back to the
dask-backed path.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import xarray as xr

from atlite.aggregate import (
    finalize_aggregated_result,
    wrap_matrix_result,
)

if TYPE_CHECKING:
    from collections.abc import Callable

    import pandas as pd
    from scipy.sparse import spmatrix

    from atlite._types import DataArray
    from atlite.cutout import Cutout

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class StreamSpec:
    """Spatial / temporal metadata needed by the streaming loop."""

    # Full cutout coordinate vectors, used to index the final outputs.
    time: xr.DataArray
    y: xr.DataArray
    x: xr.DataArray
    # Flattened spatial size of one converted chunk (out_ny * out_nx).
    n_spatial: int
    # Output grid extent of a converted chunk (may differ from the cutout
    # grid when the conversion drops a spatial dim).
    out_ny: int
    out_nx: int
    # "units" attribute taken from the first converted chunk ("" if absent).
    units: str

    @property
    def n_time(self) -> int:
        # Number of time steps in the full cutout.
        return len(self.time)


def optimal_chunk_size(ds: xr.Dataset) -> int:
    """Derive the time-chunk size from the on-disk encoding of *ds*.

    Falls back to 2190 (≈ one quarter of hourly year data) when no
    ``chunksizes`` metadata is found.
+ """ + time_sizes: set[int] = set() + for var in ds.data_vars: + chunksizes = ds[var].encoding.get("chunksizes") + if chunksizes and len(chunksizes) > 0: + time_sizes.add(chunksizes[0]) + return max(time_sizes) if time_sizes else 2190 + + +def can_stream(cutout: Cutout) -> bool: + """Return ``True`` when *cutout* is backed by an on-disk NetCDF file.""" + source = cutout.data.encoding.get("source") + return source is not None and cutout.path.is_file() + + +def has_streaming_benefit( + matrix: spmatrix | None, + aggregate_time_method: Literal["sum", "mean"] | None, +) -> bool: + """Return ``True`` when streaming can reduce peak memory over dask. + + Streaming helps when either a sparse-matrix aggregation collapses the + spatial grid or a temporal reduction (sum/mean) allows an accumulator + instead of a full output buffer. + """ + return matrix is not None or aggregate_time_method in ("sum", "mean") + + +def validate_output( + da_chunk: xr.DataArray, + ds_chunk: xr.Dataset, + convert_func: Callable[..., Any], +) -> bool: + """Check that *convert_func* produced streamable output. + + Returns ``False`` (with a debug log message) when the output is + missing a ``time`` dimension, has a different time length than the + input chunk, or contains unexpected extra dimensions. 
+ """ + if "time" not in da_chunk.dims: + logger.debug( + "Streaming aborted: %s output has no 'time' dimension.", + convert_func.__name__, + ) + return False + + if len(da_chunk["time"]) != len(ds_chunk["time"]): + logger.debug( + "Streaming aborted: %s changed time dimension (%d → %d).", + convert_func.__name__, + len(ds_chunk["time"]), + len(da_chunk["time"]), + ) + return False + + extra_dims = set(da_chunk.dims) - {"time", "y", "x"} + if extra_dims: + logger.debug( + "Streaming aborted: %s has unexpected dims %s.", + convert_func.__name__, + extra_dims, + ) + return False + + return True + + +def build_stream_spec( + cutout: Cutout, + da_chunk: xr.DataArray, +) -> StreamSpec: + """Build a :class:`StreamSpec` from the cutout and a sample converted chunk.""" + y = cutout.data["y"] + x = cutout.data["x"] + out_ny = da_chunk.sizes.get("y", len(y)) + out_nx = da_chunk.sizes.get("x", len(x)) + return StreamSpec( + time=cutout.data["time"], + y=y, + x=x, + n_spatial=out_ny * out_nx, + out_ny=out_ny, + out_nx=out_nx, + units=da_chunk.attrs.get("units", ""), + ) + + +def init_buffers( + spec: StreamSpec, + matrix: spmatrix | None, +) -> tuple[np.ndarray | None, np.ndarray | None]: + """Allocate the output buffer(s) for the streaming loop. + + Returns ``(result_data, accum)`` where exactly one is non-``None``: + *result_data* for matrix aggregation, *accum* for temporal reduction. 
+ """ + if matrix is not None: + return np.empty((spec.n_time, matrix.shape[0]), dtype=np.float64), None + return None, np.zeros((spec.out_ny, spec.out_nx), dtype=np.float64) + + +def finalize_matrix( + spec: StreamSpec, + result_data: np.ndarray, + matrix: spmatrix, + index: pd.Index, + per_unit: bool, + return_capacity: bool, + aggregate_time_method: Literal["sum", "mean"] | None, +) -> DataArray | tuple[DataArray, DataArray]: + """Wrap the matrix-aggregated buffer and apply per-unit / time aggregation.""" + result = wrap_matrix_result(result_data, spec.time, index) + return finalize_aggregated_result( + result, matrix, index, per_unit, return_capacity, aggregate_time_method + ) + + +def finalize_grid( + spec: StreamSpec, + accum: np.ndarray, + aggregate_time_method: Literal["sum", "mean"] | None, +) -> DataArray: + """Build a spatially-indexed DataArray from the temporal accumulator.""" + coords = {"y": spec.y, "x": spec.x} + values = accum if aggregate_time_method == "sum" else accum / spec.n_time + result = xr.DataArray(values, dims=("y", "x"), coords=coords) + if spec.units: + result.attrs["units"] = spec.units + return result + + +def stream_conversion( + cutout: Cutout, + convert_func: Callable[..., Any], + matrix: spmatrix | None, + index: pd.Index | None, + per_unit: bool, + return_capacity: bool, + aggregate_time: Literal["sum", "mean"] | None, + show_progress: bool, + convert_kwds: dict[str, Any], +) -> DataArray | tuple[DataArray, DataArray] | None: + """Execute *convert_func* on *cutout* one time-chunk at a time. + + Returns ``None`` when streaming offers no memory benefit for the + given arguments or when the conversion output is not streamable, so + the caller can fall back to the dask-backed path. 
+ """ + if not has_streaming_benefit(matrix, aggregate_time): + return None + + source = cutout.data.encoding.get("source") + ds_eager = xr.open_dataset(source, chunks=None) + try: + return stream_inner( + ds_eager, + cutout, + convert_func, + matrix, + index, + per_unit, + return_capacity, + aggregate_time, + show_progress, + convert_kwds, + ) + finally: + ds_eager.close() + + +def stream_inner( + ds_eager: xr.Dataset, + cutout: Cutout, + convert_func: Callable[..., Any], + matrix: spmatrix | None, + index: pd.Index | None, + per_unit: bool, + return_capacity: bool, + aggregate_time_method: Literal["sum", "mean"] | None, + show_progress: bool, + convert_kwds: dict[str, Any], +) -> DataArray | tuple[DataArray, DataArray] | None: + """Core streaming loop over time-chunks of *ds_eager*. + + Reads one storage-aligned chunk at a time, runs *convert_func* + eagerly, and either multiplies by the sparse *matrix* or accumulates + into a temporal reducer. Returns ``None`` when the first chunk + reveals that the conversion is not streamable. 
+ """ + chunk_size = optimal_chunk_size(ds_eager) + time_idx = cutout.data["time"] + n_time = len(time_idx) + + src_times = ds_eager["time"].values + target_times = time_idx.values + + matrix_t = matrix.T.tocsc() if matrix is not None else None + result_data = None + accum = None + spec: StreamSpec | None = None + + chunks_iter = list(range(0, n_time, chunk_size)) + n_chunks = len(chunks_iter) + + for i_chunk, start in enumerate(chunks_iter): + end = min(start + chunk_size, n_time) + chunk_times = target_times[start:end] + mask = np.isin(src_times, chunk_times) + idx = np.nonzero(mask)[0] + sl_src = slice(int(idx[0]), int(idx[-1]) + 1) + + ds_chunk = ds_eager.isel(time=sl_src) + da_chunk = convert_func(ds_chunk, **convert_kwds) + + if spec is None: + if not validate_output(da_chunk, ds_chunk, convert_func): + return None + spec = build_stream_spec(cutout, da_chunk) + result_data, accum = init_buffers(spec, matrix) + + chunk_values = da_chunk.values + sl_out = slice(start, end) + + if matrix is not None: + flat = chunk_values.reshape(chunk_values.shape[0], spec.n_spatial) + result_data[sl_out] = flat @ matrix_t + else: + accum += chunk_values.sum(axis=0) + + if show_progress: + logger.info( + "Streaming %s: chunk %d/%d", + convert_func.__name__, + i_chunk + 1, + n_chunks, + ) + + if matrix is not None: + return finalize_matrix( + spec, + result_data, + matrix, + index, + per_unit, + return_capacity, + aggregate_time_method, + ) + return finalize_grid(spec, accum, aggregate_time_method) diff --git a/pyproject.toml b/pyproject.toml index 5bfa288b..2103c208 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,8 @@ Documentation = "https://atlite.readthedocs.io/en/latest/" dev = ["pre-commit", "pytest", "pytest-cov", "matplotlib", "ruff", "mypy", "types-PyYAML"] +benchmark = ["dask[distributed]", "pyyaml", "geopandas", "shapely"] + docs = [ "numpydoc==1.8.0", "sphinx==8.0.2", @@ -75,7 +77,7 @@ docs = [ version_scheme = "no-guess-dev" 
[tool.setuptools.packages.find]
-include = ["atlite"]
+include = ["atlite", "atlite.*"]

# Formatter and linter settings
diff --git a/test/test_aggregate_time.py b/test/test_aggregate_time.py
index 270fae91..94d819ed 100644
--- a/test/test_aggregate_time.py
+++ b/test/test_aggregate_time.py
@@ -167,3 +167,20 @@
    def test_aggregate_time_false_raises(self, cutout):
        with pytest.raises(ValueError, match="aggregate_time must be"):
            convert_and_aggregate(cutout, identity_convert, aggregate_time=True)  # type: ignore[arg-type]

    # Unknown backend names are rejected early with a clear message.
    def test_invalid_backend_raises(self, cutout):
        with pytest.raises(ValueError, match="backend must be"):
            convert_and_aggregate(cutout, identity_convert, backend="invalid")  # type: ignore[arg-type]

    # dask_kwargs are meaningless outside the dask backend.
    def test_dask_kwargs_require_dask_backend(self, cutout):
        with pytest.raises(ValueError, match="dask_kwargs require backend='dask'"):
            convert_and_aggregate(
                cutout,
                identity_convert,
                backend="auto",
                dask_kwargs={"scheduler": "threads"},
            )

    # The in-memory test cutout has no on-disk source to stream from.
    def test_streaming_backend_requires_streamable_cutout(self, cutout):
        with pytest.raises(ValueError, match="backend='streaming' requires"):
            convert_and_aggregate(cutout, identity_convert, backend="streaming")
diff --git a/test/test_creation.py b/test/test_creation.py
index 055b089a..68d713a8 100755
--- a/test/test_creation.py
+++ b/test/test_creation.py
@@ -14,6 +14,7 @@
import numpy as np
import pytest
import rasterio as rio
+import xarray as xr
from xarray.testing import assert_equal

from atlite import Cutout
@@ -126,6 +127,27 @@
def test_auto_chunking(ref):
    assert_equal(cutout.coords.to_dataset(), ref.coords.to_dataset())


def test_storage_aligned_time_chunking(tmp_path):
    # A file stored with time chunks of 4 should be opened with
    # matching dask chunks.
    path = tmp_path / "storage_aligned.nc"
    ds = xr.Dataset(
        {
            "temperature": xr.DataArray(
                np.zeros((6, 2, 3)),
                dims=["time", "y", "x"],
                coords={"time": range(6), "y": range(2), "x": range(3)},
            )
        },
        attrs={"module": "era5", "prepared_features": []},
    )
ds.to_netcdf(path, encoding={"temperature": {"chunksizes": (4, 2, 3)}}) + + cutout = Cutout(path) + + assert cutout.chunks == {"time": 4} + assert cutout.data.chunks is not None + assert cutout.data.chunksizes["time"] == (4, 2) + + def test_dx_dy_dt(): """ Test the properties dx, dy, dt of atlite.Cutout. diff --git a/test/test_gis.py b/test/test_gis.py index 6d003f6b..63e21570 100755 --- a/test/test_gis.py +++ b/test/test_gis.py @@ -28,10 +28,14 @@ from atlite import Cutout from atlite.gis import ( ExclusionContainer, + RasterCache, + fast_dilation, + fast_isin, pad_extent, padded_transform_and_shape, regrid, shape_availability, + shape_availability_cached, ) TIME = "2013-01-01" @@ -635,6 +639,111 @@ def test_shape_availability_exclude_raster_codes(ref, raster_codes): assert ratio == masked.sum() / masked.size +@pytest.mark.parametrize( + "codes,expected_ratio", + [ + (range(20), 0.8), + (range(50), 0.5), + ([99], 0.99), + ], +) +def test_fast_isin(codes, expected_ratio): + rng = np.random.default_rng(42) + arr = rng.integers(0, 100, size=(100, 100), dtype=np.uint8) + result = fast_isin(arr, list(codes)) + assert result.dtype == bool + assert np.array_equal(result, np.isin(arr, list(codes))) + + +def test_fast_isin_float_fallback(): + arr = np.array([0.5, 1.5, 2.5]) + result = fast_isin(arr, [1, 2]) + assert np.array_equal(result, np.isin(arr, [1, 2])) + + +@pytest.mark.parametrize("iterations", [1, 3, 5, 11]) +def test_fast_dilation(iterations): + rng = np.random.default_rng(42) + mask = rng.random((200, 200)) < 0.05 + from scipy.ndimage import binary_dilation, generate_binary_structure + + struct = generate_binary_structure(2, 1) + expected = binary_dilation(mask, structure=struct, iterations=iterations) + result = fast_dilation(mask, iterations) + assert np.array_equal(result, expected) + + +def test_fast_dilation_zero_iterations(): + mask = np.array([[True, False], [False, True]]) + assert np.array_equal(fast_dilation(mask, 0), mask) + + +def 
test_raster_cache_deduplicates(raster):
    # Two entries pointing at the same file must share one cached read.
    excluder = ExclusionContainer(crs=4326, res=0.01)
    excluder.add_raster(raster)
    excluder.add_raster(raster, codes=[1], invert=True)
    excluder.open_files()
    cache = RasterCache(excluder)
    assert len(cache._data) == 1


def test_raster_cache_no_overlap(raster):
    excluder = ExclusionContainer(crs=4326, res=0.01)
    excluder.add_raster(raster, allow_no_overlap=True)
    excluder.open_files()
    cache = RasterCache(excluder)
    # Geometry far outside the raster extent: expect a nodata fill (255).
    far_away = gpd.GeoSeries([box(90, 90, 91, 91)], crs=4326)
    transform, shape = padded_transform_and_shape(far_away.total_bounds, 0.01)
    result, _ = cache.window_read(
        excluder.rasters[0]["raster"],
        far_away,
        transform,
        shape,
        crs=4326,
        allow_no_overlap=True,
    )
    assert (result == 255).all()


def test_shape_availability_cached_matches_original(ref, raster):
    # Cached path must be a drop-in replacement for shape_availability.
    shapes = gpd.GeoSeries([box(X0, Y0, X1, Y1)], crs=ref.crs)
    res = 0.01
    excluder = ExclusionContainer(ref.crs, res=res)
    excluder.add_raster(raster)
    excluder.add_raster(raster, codes=[1], invert=True)
    excluder.open_files()

    shapes_proj = shapes.to_crs(excluder.crs)
    cache = RasterCache(excluder)
    cached, trans_c = shape_availability_cached(shapes_proj, excluder, cache)
    orig, trans_o = shape_availability(shapes_proj, excluder)
    assert trans_c == trans_o
    assert np.array_equal(cached, orig)


def test_availability_matrix_threaded_with_geometry(ref, raster):
    shapes = gpd.GeoSeries(
        [
            box(X0 + 1, Y0 + 1, X1 - 1, Y0 / 2 + Y1 / 2),
            box(X0 + 1, Y0 / 2 + Y1 / 2, X1 - 1, Y1 - 1),
        ],
        crs=ref.crs,
    ).rename_axis("shape")
    exclude = gpd.GeoSeries(
        [box(X0 / 2 + X1 / 2, Y0 / 2 + Y1 / 2, X1, Y1)], crs=ref.crs
    )
    excluder_s = ExclusionContainer(ref.crs, res=0.01)
    excluder_s.add_raster(raster)
    excluder_s.add_geometry(exclude)
    ds_serial = ref.availabilitymatrix(shapes, excluder_s)

    # Fresh container for the parallel run: containers hold open file
    # handles that must not be shared across processes.
    excluder_p = ExclusionContainer(ref.crs, res=0.01)
    excluder_p.add_raster(raster)
    excluder_p.add_geometry(exclude)
    ds_parallel = ref.availabilitymatrix(shapes, excluder_p, nprocesses=2)
    # Parallel (2-process) computation must match the serial result.
    assert np.allclose(ds_serial, ds_parallel)


def test_plot_shape_availability(ref, raster):
    """Test plotting of shape availability."""
    shapes = gpd.GeoSeries([box(X0, Y0, X1, Y1)], crs=ref.crs)