Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions modelskill/comparison/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,40 @@ def sel(

return cc

def drop(
self,
model: Optional[str] = None,
observation: Optional[str] = None,
quantity: Optional[str] = None,
) -> "ComparerCollection":
"""Drop data based on model, observation and/or quantity.

Parameters
----------
model : str, optional
Model name. If None, all models are kept.
observation : str, optional
Observation name. If None, all observations are kept.
quantity : str, optional
Quantity name. If None, all quantities are kept.
"""

remaining_models = (
[m for m in self.mod_names if m not in model] if model else None
)
remaing_observations = (
[o for o in self.obs_names if o not in observation] if observation else None
)
remaining_quantities = (
[q for q in self.quantity_names if q not in quantity] if quantity else None
)

return self.sel(
model=remaining_models,
observation=remaing_observations,
quantity=remaining_quantities,
)

def filter_by_attrs(self, **kwargs: Any) -> "ComparerCollection":
"""Filter by comparer attrs similar to xarray.Dataset.filter_by_attrs

Expand Down
25 changes: 25 additions & 0 deletions modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,31 @@ def sel(
d = d.isel(time=mask)
return Comparer.from_matched_data(data=d, raw_mod_data=raw_mod_data)

def drop(
self,
model: IdxOrNameTypes,
) -> "Comparer":
"""Drop data based on model.

Parameters
----------
model : str, optional
Model name or index.
"""
if isinstance(model, (str, int)):
models = [model]
else:
models = list(model)
mod_names = [_get_name(m, self.mod_names) for m in models]
dropped_models = [m for m in self.mod_names if m in mod_names]
d = self.data.drop_vars(dropped_models)
raw_mod_data = {
m: self.raw_mod_data[m]
for m in self.raw_mod_data.keys()
if m not in dropped_models
}
return Comparer.from_matched_data(data=d, raw_mod_data=raw_mod_data)

def where(
self,
cond: Union[bool, np.ndarray, xr.DataArray],
Expand Down
6 changes: 6 additions & 0 deletions tests/test_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,12 @@ def test_pc_sel_model(pc):
assert np.all(pc2.raw_mod_data["m2"] == pc.raw_mod_data["m2"])


def test_pc_drop_model(pc):
pc2 = pc.drop(model="m2")
assert "m2" not in pc2.mod_names
assert "m1" in pc2.mod_names


def test_pc_sel_model_first(pc):
pc2 = pc.sel(model=0)
assert pc2.n_points == 5
Expand Down
12 changes: 12 additions & 0 deletions tests/test_comparercollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ def test_cc_sel_model_m3(cc):
assert cc2.n_models == 1


def test_cc_drop_model(cc):
cc2 = cc.drop(model="m1")
assert "m1" not in cc2.mod_names
assert "m3" in cc2.mod_names


def test_cc_drop_observation(cc):
cc2 = cc.drop(observation="fake point obs")
assert "fake point obs" not in cc2.obs_names
assert "fake track obs" in cc2.obs_names


def test_cc_sel_model_last(cc):
# last is m3 which is not in the first comparer
cc2 = cc.sel(model=-1)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_multivariable_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ def test_mv_mm_skill(cc):
assert pytest.approx(df.loc[idx].rmse) == 1.30535897


def test_drop_quantity(cc) -> None:
cc2 = cc.drop(quantity="Wind speed")

# Original object is not changed
assert "Wind speed" in cc.quantity_names

# New object does not contain the dropped quantity
assert "Wind speed" not in cc2.quantity_names
assert "Significant wave height" in cc.quantity_names


def test_mv_mm_mean_skill(cc):
df = cc.mean_skill().to_dataframe()
assert df.index.names[0] == "model"
Expand Down