Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions src/modelskill/model/vertical.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Any, Literal, Sequence
from typing import Any, Sequence

import xarray as xr
import pandas as pd
Expand Down Expand Up @@ -35,11 +35,6 @@ class VerticalModelResult(TimeSeries):
zonal coordinate of point position, inferred from data if not given, else None
quantity : Quantity, optional
Model quantity, for MIKE files this is inferred from the EUM information
keep_duplicates : (str, bool), optional
Strategy for handling duplicate timestamps (wraps xarray.Dataset.drop_duplicates)
"first" to keep first occurrence, "last" to keep last occurrence,
False to drop all duplicates, "offset" to add milliseconds to
consecutive duplicates, by default "first"
aux_items : list[int | str] | None, optional
Auxiliary items, by default None
"""
Expand All @@ -54,7 +49,6 @@ def __init__(
z_item: str | int = 0,
x: float | None = None,
y: float | None = None,
keep_duplicates: Literal["first", "last", False] = "first",
aux_items: Sequence[int | str] | None = None,
) -> None:
if not self._is_input_validated(data):
Expand All @@ -66,7 +60,6 @@ def __init__(
z_item=z_item,
x=x,
y=y,
keep_duplicates=keep_duplicates,
aux_items=aux_items,
)
assert isinstance(data, xr.Dataset)
Expand Down
4 changes: 0 additions & 4 deletions src/modelskill/obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,6 @@ class VerticalObservation(Observation):
User-defined name for identification in plots and summaries.
weight : float, optional
Weighting factor for skill scores, by default 1.0.
keep_duplicates : {"first", "last", False}, optional
Strategy for handling duplicate timestamps/z pairs.
quantity : Quantity, optional
Physical quantity metadata used for validation against model results.
aux_items : list[int | str], optional
Expand Down Expand Up @@ -445,7 +443,6 @@ def __init__(
z_item: int | str | None = 0,
name: str | None = None,
weight: float = 1.0,
keep_duplicates: Literal["first", "last", False] = "first",
quantity: Quantity | None = None,
aux_items: list[int | str] | None = None,
attrs: dict | None = None,
Expand All @@ -460,7 +457,6 @@ def __init__(
z_item=z_item,
x=x,
y=y,
keep_duplicates=keep_duplicates,
)
assert isinstance(data, xr.Dataset)
super().__init__(data=data, weight=weight, attrs=attrs)
Expand Down
19 changes: 7 additions & 12 deletions src/modelskill/timeseries/_vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from collections.abc import Hashable
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, get_args, Optional, List, Sequence
import warnings
from typing import get_args, Optional, List, Sequence
import pandas as pd
import xarray as xr
from ._coords import XYZCoords
Expand Down Expand Up @@ -88,7 +87,6 @@ def _parse_vertical_input(
z_item: str | int | None,
x: float | None = None,
y: float | None = None,
keep_duplicates: Literal["first", "last", False] = "first",
aux_items: Optional[Sequence[int | str]] = None,
) -> xr.Dataset:
assert isinstance(
Expand Down Expand Up @@ -142,16 +140,13 @@ def _parse_vertical_input(

ds = ds.rename({sel_items.z: "z"})

# keep first, last or none of duplicate (time, z) pairs
idx_df = pd.DataFrame({"time": ds["time"].to_index(), "z": ds["z"].values})

keep_mask = ~idx_df.duplicated(subset=["time", "z"], keep=keep_duplicates)

n_removed = int((~keep_mask).sum())
ds = ds.isel(time=keep_mask.values)
if n_removed > 0:
warnings.warn(
f"Removed {n_removed} duplicate (time, z) entries with keep={keep_duplicates}"
n_duplicates = int(idx_df.duplicated(subset=["time", "z"]).sum())
if n_duplicates > 0:
raise ValueError(
f"Input contains {n_duplicates} duplicate (time, z) entries. "
"Vertical profiles must have a unique depth per timestamp; "
"deduplicate the input before constructing the object."
)

ds = ds.dropna(dim="time", subset=["z"]) # remove times with z as nan
Expand Down
56 changes: 11 additions & 45 deletions tests/model/test_vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,6 @@ def vertical_model_df() -> pd.DataFrame:
)


@pytest.fixture
def vertical_model_df_duplicates() -> pd.DataFrame:
return pd.DataFrame(
{
"z": [-5.0, -5.0, -4.0, -4.0, -3.0],
"Salinity": [30.0, 300.0, 31.0, 310.0, 32.0],
},
index=pd.to_datetime(
[
"2019-01-01 00:00:00",
"2019-01-01 00:00:00",
"2019-01-01 00:00:00",
"2019-01-01 00:00:00",
"2019-01-01 01:00:00",
]
),
)


@pytest.fixture
def vertical_model_df_aux() -> pd.DataFrame:
return pd.DataFrame(
Expand Down Expand Up @@ -136,38 +117,23 @@ def test_item_named_z(self, dfs0_ds):
with pytest.raises(ValueError, match="name 'z' is reserved "):
_ = ms.VerticalModelResult(ds_test)

# ===============
# test arguments options for handling duplicates
# ===============
@pytest.mark.parametrize(
"keep_duplicates,expected_removed,expected_z,expected_values",
[
("first", 2, [-5.0, -4.0, -3.0], [30.0, 31.0, 32.0]),
("last", 2, [-5.0, -4.0, -3.0], [300.0, 310.0, 32.0]),
(False, 4, [-3.0], [32.0]),
],
)
def test_vertical_model_keep_duplicates_modes(
self,
vertical_model_df_duplicates,
keep_duplicates,
expected_removed,
expected_z,
expected_values,
):
with pytest.warns(UserWarning, match=f"Removed {expected_removed} duplicate"):
mr = ms.VerticalModelResult(
vertical_model_df_duplicates,
def test_duplicate_time_z_pairs_raises(self):
df = pd.DataFrame(
{
"z": [-5.0, -4.0, -4.0],
"Salinity": [30.0, 31.0, 31.5],
},
index=[pd.Timestamp("2019-01-01")] * 3,
)
with pytest.raises(ValueError, match="duplicate \\(time, z\\) entries"):
ms.VerticalModelResult(
df,
item="Salinity",
z_item="z",
x=12.0,
y=55.0,
keep_duplicates=keep_duplicates,
)

assert list(mr.data["z"].values) == expected_z
assert list(mr.data[mr.name].values) == expected_values

# aux items
def test_vertical_model_aux_items_preserved_and_tagged(self, vertical_model_df_aux):
mr = ms.VerticalModelResult(
Expand Down
66 changes: 2 additions & 64 deletions tests/observation/test_vertical_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,6 @@ def _vertical_df() -> pd.DataFrame:
)


@pytest.fixture
def _vertical_df_duplicates() -> pd.DataFrame:
return pd.DataFrame(
{
"z": [-5.0, -5.0, -4.0, -4.0, -3.0],
"value": [1.0, 10.0, 2.0, 20.0, 7.0],
},
index=pd.to_datetime(
[
"2019-01-01 00:00:00",
"2019-01-01 00:00:00",
"2019-01-01 00:00:00",
"2019-01-01 00:00:00",
"2019-01-01 01:00:00",
]
),
)


@pytest.fixture
def _vertical_df_aux() -> pd.DataFrame:
return pd.DataFrame(
Expand Down Expand Up @@ -114,15 +95,15 @@ def test_with_and_without_item_arg(self):
fn = Path("tests/testdata/vertical/VerticalProfile_obs1.dfs0")
assert isinstance(ms.observation(fn, z_item="z"), ms.VerticalObservation)

def test_duplicated_time_z_pairs(self):
def test_duplicate_time_z_pairs_raises(self):
df = pd.DataFrame(
{
"z": [-5.0, -4.0, -4.0],
"value": [1.0, 1.1, 1.2],
},
index=[pd.Timestamp("2019-01-01")] * 3,
)
with pytest.warns(UserWarning, match="Removed 1 duplicate"):
with pytest.raises(ValueError, match="duplicate \\(time, z\\) entries"):
ms.VerticalObservation(
df,
item="value",
Expand All @@ -131,49 +112,6 @@ def test_duplicated_time_z_pairs(self):
y=55.0,
)

def test_keep_duplicates_last_is_applied(self, _vertical_df_duplicates):
with pytest.warns(UserWarning, match="Removed 2 duplicate"):
obs = ms.VerticalObservation(
_vertical_df_duplicates,
item="value",
z_item="z",
x=12.0,
y=55.0,
keep_duplicates="last",
)

assert list(obs.data["z"].values) == [-5.0, -4.0, -3.0]
assert list(obs.data["value"].values) == [10.0, 20.0, 7.0]

@pytest.mark.parametrize(
"keep_duplicates,expected_removed,expected_z,expected_values",
[
("first", 2, [-5.0, -4.0, -3.0], [1.0, 2.0, 7.0]),
("last", 2, [-5.0, -4.0, -3.0], [10.0, 20.0, 7.0]),
(False, 4, [-3.0], [7.0]),
],
)
def test_keep_duplicates_modes(
self,
_vertical_df_duplicates,
keep_duplicates,
expected_removed,
expected_z,
expected_values,
):
with pytest.warns(UserWarning, match=f"Removed {expected_removed} duplicate"):
obs = ms.VerticalObservation(
_vertical_df_duplicates,
item="value",
z_item="z",
x=12.0,
y=55.0,
keep_duplicates=keep_duplicates,
)

assert list(obs.data["z"].values) == expected_z
assert list(obs.data["value"].values) == expected_values

def test_single_item_input_raises(self):
df = pd.DataFrame(
{"value": [1.0, 1.1, 1.2]},
Expand Down
Loading