diff --git a/modelskill/comparison/_collection.py b/modelskill/comparison/_collection.py
index 4961f824a..bf02af1a2 100644
--- a/modelskill/comparison/_collection.py
+++ b/modelskill/comparison/_collection.py
@@ -348,6 +348,43 @@ def sel(
 
         return cc
 
+    def drop(
+        self,
+        model: Optional[str] = None,
+        observation: Optional[str] = None,
+        quantity: Optional[str] = None,
+    ) -> "ComparerCollection":
+        """Drop data based on model, observation and/or quantity.
+
+        Parameters
+        ----------
+        model : str, optional
+            Model name. If None, all models are kept.
+        observation : str, optional
+            Observation name. If None, all observations are kept.
+        quantity : str, optional
+            Quantity name. If None, all quantities are kept.
+
+        Returns
+        -------
+        ComparerCollection
+            Result of ``sel`` on the names that were not dropped.
+        """
+
+        def _remaining(names: list, dropped: Optional[str]) -> Optional[list]:
+            # A bare string must be compared as one whole name; `name in "m1"`
+            # would be a substring test and also drop e.g. a model called "m".
+            if not dropped:
+                return None
+            excluded = [dropped] if isinstance(dropped, str) else list(dropped)
+            return [n for n in names if n not in excluded]
+
+        return self.sel(
+            model=_remaining(self.mod_names, model),
+            observation=_remaining(self.obs_names, observation),
+            quantity=_remaining(self.quantity_names, quantity),
+        )
+
     def filter_by_attrs(self, **kwargs: Any) -> "ComparerCollection":
         """Filter by comparer attrs similar to xarray.Dataset.filter_by_attrs
 
diff --git a/modelskill/comparison/_comparison.py b/modelskill/comparison/_comparison.py
index 1c4b5a475..ada3e902a 100644
--- a/modelskill/comparison/_comparison.py
+++ b/modelskill/comparison/_comparison.py
@@ -899,6 +899,36 @@ def sel(
             d = d.isel(time=mask)
         return Comparer.from_matched_data(data=d, raw_mod_data=raw_mod_data)
 
+    def drop(
+        self,
+        model: IdxOrNameTypes,
+    ) -> "Comparer":
+        """Drop one or more models from the comparer.
+
+        Parameters
+        ----------
+        model : str, int or iterable of str/int
+            Name or index of the model(s) to drop.
+
+        Returns
+        -------
+        Comparer
+            New comparer built from the data without the dropped model(s).
+        """
+        if isinstance(model, (str, int)):
+            models = [model]
+        else:
+            models = list(model)
+        mod_names = [_get_name(m, self.mod_names) for m in models]
+        dropped_models = [m for m in self.mod_names if m in mod_names]
+        d = self.data.drop_vars(dropped_models)
+        raw_mod_data = {
+            name: raw
+            for name, raw in self.raw_mod_data.items()
+            if name not in dropped_models
+        }
+        return Comparer.from_matched_data(data=d, raw_mod_data=raw_mod_data)
+
     def where(
         self,
         cond: Union[bool, np.ndarray, xr.DataArray],
diff --git a/tests/test_comparer.py b/tests/test_comparer.py
index 2d8691d0c..cd2f6e8a3 100644
--- a/tests/test_comparer.py
+++ b/tests/test_comparer.py
@@ -462,6 +462,12 @@ def test_pc_sel_model(pc):
     assert np.all(pc2.raw_mod_data["m2"] == pc.raw_mod_data["m2"])
 
 
+def test_pc_drop_model(pc):
+    pc2 = pc.drop(model="m2")
+    assert "m2" not in pc2.mod_names
+    assert "m1" in pc2.mod_names
+
+
 def test_pc_sel_model_first(pc):
     pc2 = pc.sel(model=0)
     assert pc2.n_points == 5
diff --git a/tests/test_comparercollection.py b/tests/test_comparercollection.py
index 25671fa74..6f450ec96 100644
--- a/tests/test_comparercollection.py
+++ b/tests/test_comparercollection.py
@@ -131,6 +131,18 @@ def test_cc_sel_model_m3(cc):
     assert cc2.n_models == 1
 
 
+def test_cc_drop_model(cc):
+    cc2 = cc.drop(model="m1")
+    assert "m1" not in cc2.mod_names
+    assert "m3" in cc2.mod_names
+
+
+def test_cc_drop_observation(cc):
+    cc2 = cc.drop(observation="fake point obs")
+    assert "fake point obs" not in cc2.obs_names
+    assert "fake track obs" in cc2.obs_names
+
+
 def test_cc_sel_model_last(cc):
     # last is m3 which is not in the first comparer
     cc2 = cc.sel(model=-1)
diff --git a/tests/test_multivariable_compare.py b/tests/test_multivariable_compare.py
index c432f78d8..d176f6c43 100644
--- a/tests/test_multivariable_compare.py
+++ b/tests/test_multivariable_compare.py
@@ -120,6 +120,17 @@ def test_mv_mm_skill(cc):
     assert pytest.approx(df.loc[idx].rmse) == 1.30535897
 
 
+def test_drop_quantity(cc) -> None:
+    cc2 = cc.drop(quantity="Wind speed")
+
+    # Original object is not changed
+    assert "Wind speed" in cc.quantity_names
+
+    # New object lacks the dropped quantity but keeps the other one
+    assert "Wind speed" not in cc2.quantity_names
+    assert "Significant wave height" in cc2.quantity_names
+
+
 def test_mv_mm_mean_skill(cc):
     df = cc.mean_skill().to_dataframe()
     assert df.index.names[0] == "model"