26 changes: 22 additions & 4 deletions climada/engine/impact.py
@@ -1431,6 +1431,8 @@ def write_attribute(group, name, value):

def write_dataset(group, name, value):
"""Write a dataset"""
if name == "lead_time":
value = value.astype("timedelta64[ns]").astype("int64")
group.create_dataset(name, data=value, dtype=_str_type_helper(value))

def write_dict(group, name, value):
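The `write_dataset` change above exists because h5py has no native HDF5 mapping for NumPy ``timedelta64`` values: ``lead_time`` is cast to integer nanoseconds on write and cast back on read (see the ``from_hdf5`` hunk below). A minimal sketch of that round trip, assuming only h5py and NumPy; the file name is illustrative:

```python
import h5py
import numpy as np

lead_time = np.array([6, 12, 24], dtype="timedelta64[h]")

with h5py.File("demo.h5", "w") as file:
    # Write side: store as int64 nanoseconds, mirroring write_dataset above.
    file.create_dataset(
        "lead_time", data=lead_time.astype("timedelta64[ns]").astype("int64")
    )

with h5py.File("demo.h5", "r") as file:
    # Read side: restore the timedelta dtype, mirroring from_hdf5 below.
    restored = file["lead_time"][:].astype("timedelta64[ns]")

assert (restored == lead_time).all()
```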
@@ -1618,7 +1620,9 @@ def read_excel(self, *args, **kwargs):
self.__dict__ = Impact.from_excel(*args, **kwargs).__dict__

@classmethod
def from_hdf5(cls, file_path: Union[str, Path]):
def from_hdf5(
cls, file_path: Union[str, Path], *, add_scalar_attrs=None, add_array_attrs=None
):
"""Create an impact object from an H5 file.

This assumes a specific layout of the file. If values are not found in the
@@ -1663,6 +1667,10 @@ def from_hdf5(cls, file_path: Union[str, Path]):
----------
file_path : str or Path
The file path of the file to read.
add_scalar_attrs : Iterable of str, optional
Additional scalar attributes to read from the file, on top of the defaults. Defaults to None.
add_array_attrs : Iterable of str, optional
Additional array attributes (H5 datasets) to read from the file, on top of the defaults. Defaults to None.

Returns
-------
@@ -1691,17 +1699,27 @@ def from_hdf5(cls, file_path: Union[str, Path]):
# Scalar attributes
scalar_attrs = set(
("crs", "tot_value", "unit", "aai_agg", "frequency_unit", "haz_type")
).intersection(file.attrs.keys())
)
if add_scalar_attrs is not None:
scalar_attrs = scalar_attrs.union(add_scalar_attrs)
scalar_attrs = scalar_attrs.intersection(file.attrs.keys())
kwargs.update({attr: file.attrs[attr] for attr in scalar_attrs})

# Array attributes
# NOTE: Need [:] to copy array data. Otherwise, it would be a view that is
# invalidated once we close the file.
array_attrs = set(
("event_id", "date", "coord_exp", "eai_exp", "at_event", "frequency")
).intersection(file.keys())
)
if add_array_attrs is not None:
array_attrs = array_attrs.union(add_array_attrs)
array_attrs = array_attrs.intersection(file.keys())
kwargs.update({attr: file[attr][:] for attr in array_attrs})

# Convert the lead_time attribute back to timedelta
if "lead_time" in kwargs:
kwargs["lead_time"] = np.array(file["lead_time"][:]).astype(
"timedelta64[ns]"
)
# Special handling for 'event_name' because it should be a list of strings
if "event_name" in file:
# pylint: disable=no-member
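The new ``add_scalar_attrs``/``add_array_attrs`` hooks implement one pattern twice: start from the default attribute names, union in any caller-supplied extras, then keep only the names actually present in the file, so missing entries silently fall back to defaults. A standalone sketch of that selection logic (the helper name and example values are illustrative, not part of the PR):

```python
def select_attrs(defaults, extras, available):
    # Union optional extras into the defaults, then keep only names
    # actually available in the file.
    attrs = set(defaults)
    if extras is not None:
        attrs = attrs.union(extras)
    return attrs.intersection(available)

# E.g. a subclass requesting two extra array datasets:
selected = select_attrs(
    defaults=("event_id", "date", "frequency"),
    extras={"member", "lead_time"},
    available={"event_id", "frequency", "member", "lead_time"},
)
# selected == {"event_id", "frequency", "member", "lead_time"}
```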
58 changes: 58 additions & 0 deletions climada/engine/impact_forecast.py
@@ -20,6 +20,8 @@
"""

import logging
from pathlib import Path
from typing import Union

import numpy as np
import scipy.sparse as sparse
@@ -173,6 +175,62 @@ def calc_freq_curve(self, return_per=None):
LOGGER.error("calc_freq_curve is not defined for ImpactForecast")
raise NotImplementedError("calc_freq_curve is not defined for ImpactForecast")

@classmethod
def from_hdf5(cls, file_path: Union[str, Path]):
"""Create an ImpactForecast object from an H5 file.

This assumes a specific layout of the file. If values are not found in the
expected places, they will be set to the default values for an ``Impact`` object.

The following H5 file structure is assumed (H5 groups are terminated with ``/``,
attributes are denoted by ``.attrs/``)::

file.h5
├─ at_event
├─ coord_exp
├─ eai_exp
├─ event_id
├─ event_name
├─ frequency
├─ imp_mat
├─ lead_time
├─ member
├─ .attrs/
│ ├─ aai_agg
│ ├─ crs
│ ├─ frequency_unit
│ ├─ haz_type
│ ├─ tot_value
│ ├─ unit

As in :py:func:`climada.engine.impact.Impact.__init__`, each of these entries
is optional. If an entry is not found, its default value is used when constructing
the Impact.

The impact matrix ``imp_mat`` can either be an H5 dataset, in which case it is
interpreted as a dense representation of the matrix, or an H5 group, in which case
the group is expected to contain the following data for instantiating a
`scipy.sparse.csr_matrix <https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html>`_::

imp_mat/
├─ data
├─ indices
├─ indptr
├─ .attrs/
│ ├─ shape

Parameters
----------
file_path : str or Path
The file path of the file to read.

Returns
-------
imp : ImpactForecast
ImpactForecast with data from the given file
"""
return super().from_hdf5(file_path, add_array_attrs={"member", "lead_time"})

def _check_sizes(self):
"""Check sizes of forecast data vs. impact data.

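When ``imp_mat`` is stored sparsely, the docstring above specifies a group holding ``data``, ``indices``, ``indptr``, and a ``shape`` attribute. A minimal sketch of rebuilding the matrix from that layout, assuming h5py and SciPy; the path is a placeholder:

```python
import h5py
from scipy.sparse import csr_matrix

with h5py.File("impact.h5", "r") as file:
    group = file["imp_mat"]
    # Reassemble the CSR matrix from its three arrays plus the stored shape.
    imp_mat = csr_matrix(
        (group["data"][:], group["indices"][:], group["indptr"][:]),
        shape=tuple(group.attrs["shape"]),
    )
```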
30 changes: 30 additions & 0 deletions climada/engine/test/test_impact_forecast.py
@@ -233,6 +233,36 @@ def test_impact_forecast_blocked_methods(impact_forecast):
impact_forecast.calc_freq_curve(np.array([10, 50, 100]))


@pytest.mark.parametrize("dense", [True, False])
def test_write_read_hdf5(impact_forecast, tmp_path, dense):

file_name = tmp_path / "test_impact_forecast.h5"
# replace dummy_impact event_names with strings
impact_forecast.event_name = [str(name) for name in impact_forecast.event_name]
impact_forecast.write_hdf5(file_name, dense_imp_mat=dense)

def compare_attr(obj, attr):
actual = getattr(obj, attr)
expected = getattr(impact_forecast, attr)
if isinstance(actual, csr_matrix):
npt.assert_array_equal(actual.todense(), expected.todense())
else:
npt.assert_array_equal(actual, expected)

# Read ImpactForecast
impact_forecast_read = ImpactForecast.from_hdf5(file_name)
assert impact_forecast_read.lead_time.dtype.kind == np.dtype("timedelta64").kind
for attr in impact_forecast.__dict__.keys():
compare_attr(impact_forecast_read, attr)

# Read Impact
impact_read = Impact.from_hdf5(file_name)
for attr in impact_read.__dict__.keys():
compare_attr(impact_read, attr)
assert "member" not in impact_read.__dict__
assert "lead_time" not in impact_read.__dict__


@pytest.fixture
def impact_forecast_stats(impact_kwargs, lead_time, member):
max_index = 4
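For orientation, a hypothetical round trip mirroring the test above, outside pytest; the path is illustrative, ``impact_forecast`` is assumed to be a populated ``ImpactForecast``, and the import path follows the file location in this PR:

```python
from climada.engine.impact_forecast import ImpactForecast

file_name = "impact_forecast.h5"
impact_forecast.write_hdf5(file_name, dense_imp_mat=False)

restored = ImpactForecast.from_hdf5(file_name)
# lead_time comes back as a timedelta64 array (dtype kind "m"),
# and member is read via the add_array_attrs hook.
assert restored.lead_time.dtype.kind == "m"
```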