Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/modelskill/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def match(
gtype: Optional[GeometryTypes] = None,
max_model_gap: Optional[float] = None,
spatial_method: Optional[str] = None,
temporal_method: str = "linear",
obs_no_overlap: Literal["ignore", "error", "warn"] = "error",
) -> Comparer: ...

Expand All @@ -195,6 +196,7 @@ def match(
gtype: Optional[GeometryTypes] = None,
max_model_gap: Optional[float] = None,
spatial_method: Optional[str] = None,
temporal_method: str = "linear",
obs_no_overlap: Literal["ignore", "error", "warn"] = "error",
) -> ComparerCollection: ...

Expand All @@ -208,6 +210,7 @@ def match(
gtype=None,
max_model_gap=None,
spatial_method: Optional[str] = None,
temporal_method: str = "linear",
spatial_tolerance: float = 1e-3,
obs_no_overlap: Literal["ignore", "error", "warn"] = "error",
):
Expand Down Expand Up @@ -243,6 +246,11 @@ def match(
'inverse_distance' (with 5 nearest points), by default "inverse_distance".
- For GridModelResult, passed to xarray.interp() as method argument,
by default 'linear'.
temporal_method : str, optional
Temporal interpolation method passed to xarray.interp(), by default 'linear'
Valid options are: "akima", "barycentric", "cubic", "krogh", "linear",
"makima", "nearest", "pchip", "polynomial", "quadratic",
"quintic", "slinear", "spline", "zero".
spatial_tolerance : float, optional
Spatial tolerance (in the units of the coordinate system) for matching
model track points to observation track points. Model points outside
Expand Down Expand Up @@ -271,6 +279,7 @@ def match(
gtype=gtype,
max_model_gap=max_model_gap,
spatial_method=spatial_method,
temporal_method=temporal_method,
spatial_tolerance=spatial_tolerance,
obs_no_overlap=obs_no_overlap,
)
Expand Down Expand Up @@ -328,6 +337,7 @@ def _match_single_obs(
gtype: GeometryTypes | None,
max_model_gap: float | None,
spatial_method: str | None,
temporal_method: str = "linear",
spatial_tolerance: float,
obs_no_overlap: Literal["ignore", "error", "warn"],
) -> Comparer | None:
Expand All @@ -346,7 +356,10 @@ def _match_single_obs(

raw_mod_data = {
m.name: (
m.extract(observation, spatial_method=spatial_method)
m.extract(
observation,
spatial_method=spatial_method,
)
if isinstance(m, (DfsuModelResult, GridModelResult, DummyModelResult))
else m
)
Expand All @@ -358,6 +371,7 @@ def _match_single_obs(
raw_mod_data=raw_mod_data,
max_model_gap=max_model_gap,
obs_no_overlap=obs_no_overlap,
temporal_method=temporal_method,
spatial_tolerance=spatial_tolerance,
)
if matched_data is None:
Expand Down Expand Up @@ -385,6 +399,7 @@ def _match_space_time(
max_model_gap: float | None,
spatial_tolerance: float,
obs_no_overlap: Literal["ignore", "error", "warn"],
temporal_method: str = "linear",
) -> Optional[xr.Dataset]:
idxs = [m.time for m in raw_mod_data.values()]
period = _get_global_start_end(idxs)
Expand All @@ -404,7 +419,9 @@ def _match_space_time(
observation, spatial_tolerance=spatial_tolerance
)
case PointModelResult() as pmr, PointObservation():
aligned = pmr.align(observation, max_gap=max_model_gap)
aligned = pmr.align(
observation, max_gap=max_model_gap, method=temporal_method
)
case _:
raise TypeError(
f"Matching not implemented for model type {type(mr)} and observation type {type(observation)}"
Expand Down
6 changes: 5 additions & 1 deletion src/modelskill/model/point.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,16 @@ def align(
observation: Observation,
*,
max_gap: float | None = None,
method: str = "linear",
**kwargs: Any,
) -> xr.Dataset:
new_time = observation.time

dati = self.data.dropna("time").interp(
time=new_time, assume_sorted=True, **kwargs
time=new_time,
assume_sorted=True,
method=method, # type: ignore
**kwargs,
)

pmr = PointModelResult(dati)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,3 +616,26 @@ def test_multiple_models_same_name(tmp_path: Path) -> None:

with pytest.raises(ValueError, match="HKZN_local_2017_DutchCoast"):
ms.match(obs, [mr1, mr2])


def test_directional_data_use_nearest_temporal_interpolation():
mod = ms.PointModelResult(
name="mod",
data=pd.Series(
[359, 5], index=pd.date_range("2023-01-01", periods=2, freq="3H")
),
)

obs = ms.PointObservation(
name="obs",
data=pd.Series(
np.zeros(5), index=pd.date_range("2023-01-01", periods=5, freq="1H")
),
)

cmp = ms.match(
obs=obs,
mod=mod,
temporal_method="nearest",
)
assert cmp.data["mod"].values[1] == pytest.approx(359.0)