diff --git a/climada/hazard/forecast.py b/climada/hazard/forecast.py index 5130e66af1..b09e1a44e3 100644 --- a/climada/hazard/forecast.py +++ b/climada/hazard/forecast.py @@ -104,3 +104,64 @@ def _check_sizes(self): num_entries = len(self.event_id) size(exp_len=num_entries, var=self.member, var_name="Forecast.member") size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time") + + def select( + self, + member=None, + lead_time=None, + event_names=None, + event_id=None, + date=None, + orig=None, + reg_id=None, + extent=None, + reset_frequency=False, + ): + """Select entries based on the parameters and return a new instance. + + The selection will contain the intersection of all given parameters. + + Parameters + ---------- + member : Sequence of ints + Ensemble members to select + lead_time : Sequence of numpy.timedelta64 + Lead times to select + + Returns + ------- + HazardForecast + + See Also + -------- + :py:meth:`~climada.hazard.base.Hazard.select` + """ + if member is not None or lead_time is not None: + mask_member = ( + self.idx_member(member) + if member is not None + else np.full_like(self.member, True, dtype=bool) + ) + mask_lead_time = ( + self.idx_lead_time(lead_time) + if lead_time is not None + else np.full_like(self.lead_time, True, dtype=bool) + ) + event_id_from_forecast_mask = np.asarray(self.event_id)[ + (mask_member & mask_lead_time) + ] + event_id = ( + np.intersect1d(event_id, event_id_from_forecast_mask) + if event_id is not None + else event_id_from_forecast_mask + ) + + return super().select( + event_names=event_names, + event_id=event_id, + date=date, + orig=orig, + reg_id=reg_id, + extent=extent, + reset_frequency=reset_frequency, + ) diff --git a/climada/hazard/test/test_forecast.py b/climada/hazard/test/test_forecast.py index b102ee2d17..ac1a726965 100644 --- a/climada/hazard/test/test_forecast.py +++ b/climada/hazard/test/test_forecast.py @@ -107,46 +107,94 @@ def test_hazard_forecast_concat(haz_fc, lead_time, member): npt.assert_array_equal(haz_fc_concat.member, np.concatenate([member, member])) -@pytest.mark.parametrize( - "var, var_select", - [("event_id", "event_id"), ("event_name", "event_names"), ("date", "date")], -) -def test_hazard_forecast_select(haz_fc, lead_time, member, haz_kwargs, var, var_select): - """Check if Hazard.select works on the derived class""" - - select_mask = np.array([3, 2]) - ordered_select_mask = np.array([3, 2]) - if var == "date": - # Date needs to be a valid delta - select_mask = np.array([2, 3]) - ordered_select_mask = np.array([2, 3]) - - var_value = np.array(haz_kwargs[var])[select_mask] - # event_name is a list, convert to numpy array for indexing - haz_fc_sel = haz_fc.select(**{var_select: var_value}) - # Note: order is preserved - npt.assert_array_equal( - haz_fc_sel.event_id, - haz_fc.event_id[ordered_select_mask], - ) - npt.assert_array_equal( - haz_fc_sel.event_name, - np.array(haz_fc.event_name)[ordered_select_mask], - ) - npt.assert_array_equal(haz_fc_sel.date, haz_fc.date[ordered_select_mask]) - npt.assert_array_equal(haz_fc_sel.frequency, haz_fc.frequency[ordered_select_mask]) - npt.assert_array_equal(haz_fc_sel.member, member[ordered_select_mask]) - npt.assert_array_equal(haz_fc_sel.lead_time, lead_time[ordered_select_mask]) - npt.assert_array_equal( - haz_fc_sel.intensity.todense(), - haz_fc.intensity.todense()[ordered_select_mask], - ) - npt.assert_array_equal( - haz_fc_sel.fraction.todense(), - haz_fc.fraction.todense()[ordered_select_mask], - ) +class TestSelect: - assert haz_fc_sel.centroids == haz_fc.centroids + @pytest.mark.parametrize( + "var, var_select", + [("event_id", "event_id"), ("event_name", "event_names"), ("date", "date")], + ) + def test_base_class_select( + self, haz_fc, lead_time, member, haz_kwargs, var, var_select + ): + """Check if Hazard.select works on the derived class""" + + select_mask = np.array([3, 2]) + ordered_select_mask = np.array([3, 2]) + if var == "date": + # Date needs to be a valid delta + select_mask = np.array([2, 3]) + ordered_select_mask = np.array([2, 3]) + + var_value = np.array(haz_kwargs[var])[select_mask] + # event_name is a list, convert to numpy array for indexing + haz_fc_sel = haz_fc.select(**{var_select: var_value}) + # Note: order is preserved + npt.assert_array_equal( + haz_fc_sel.event_id, + haz_fc.event_id[ordered_select_mask], + ) + npt.assert_array_equal( + haz_fc_sel.event_name, + np.array(haz_fc.event_name)[ordered_select_mask], + ) + npt.assert_array_equal(haz_fc_sel.date, haz_fc.date[ordered_select_mask]) + npt.assert_array_equal( + haz_fc_sel.frequency, haz_fc.frequency[ordered_select_mask] + ) + npt.assert_array_equal(haz_fc_sel.member, member[ordered_select_mask]) + npt.assert_array_equal(haz_fc_sel.lead_time, lead_time[ordered_select_mask]) + npt.assert_array_equal( + haz_fc_sel.intensity.todense(), + haz_fc.intensity.todense()[ordered_select_mask], + ) + npt.assert_array_equal( + haz_fc_sel.fraction.todense(), + haz_fc.fraction.todense()[ordered_select_mask], + ) + + assert haz_fc_sel.centroids == haz_fc.centroids + + def test_derived_select_single(self, haz_fc, lead_time, member): + haz_fc_select = haz_fc.select(member=[3, 0]) + idx = np.array([0, 3]) + npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[idx]) + npt.assert_array_equal(haz_fc_select.member, member[idx]) + npt.assert_array_equal(haz_fc_select.lead_time, lead_time[idx]) + + haz_fc_select = haz_fc.select(lead_time=lead_time[np.array([3, 0])]) + npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[idx]) + npt.assert_array_equal(haz_fc_select.member, member[idx]) + npt.assert_array_equal(haz_fc_select.lead_time, lead_time[idx]) + + def test_derived_select_intersections(self, haz_fc, lead_time, member, haz_kwargs): + haz_fc_select = haz_fc.select(event_id=[1, 4], member=[0, 1, 2]) + npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([0])]) + + haz_fc_select = haz_fc.select( + event_id=[1, 2, 4], member=[0, 1, 2], lead_time=lead_time[1:3] + ) + npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([1])]) + + # Test "outer" + haz_fc2 = HazardForecast( + lead_time=lead_time, member=np.zeros_like(member, dtype="int"), **haz_kwargs + ) + haz_fc_select = haz_fc2.select(event_id=[1, 2, 4], member=[0]) + npt.assert_array_equal(haz_fc_select.event_id, [1, 2, 4]) + npt.assert_array_equal(haz_fc_select.member, [0, 0, 0]) + + def test_derived_select_null(self, haz_fc, haz_kwargs): + haz_fc_select = haz_fc.select() + assert_hazard_kwargs(haz_fc_select, **haz_kwargs) + + with pytest.raises(IndexError): + haz_fc.select(event_id=[-1]) + with pytest.raises(IndexError): + haz_fc.select(member=[-1]) + with pytest.raises(IndexError): + haz_fc.select( + lead_time=[np.timedelta64("2", "Y").astype("timedelta64[ns]")] + ) def test_write_read_hazard_forecast(haz_fc, tmp_path):