Skip to content

Commit bdd8a4e

Browse files
authored
Merge pull request #533 from DHI/no_guessing_in_match
Remove parsing of observation and model from match
2 parents 21b85f8 + 737e01c commit bdd8a4e

6 files changed

Lines changed: 79 additions & 258 deletions

File tree

docs/user-guide/overview.qmd

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,6 @@ If not, the [`match()`](`modelskill.match`) function can be used to match the ob
1414

1515
If the observations and model results are not in the same data source (e.g. dfs0 file),
1616
they will need to be defined and then matched in space and time with the `match()` function.
17-
In simple cases, observations and model results can be defined directly in the `match()` function:
18-
19-
```{python}
20-
import modelskill as ms
21-
cmp = ms.match("../data/obs.dfs0", "../data/model.dfs0",
22-
obs_item="obs_WL", mod_item="WL",
23-
gtype='point')
24-
```
25-
26-
But in most cases, the observations and model results will need to be defined separately first.
2717

2818

2919
### Define observations

notebooks/Simple_timeseries_compare.ipynb

Lines changed: 25 additions & 46 deletions
Large diffs are not rendered by default.

src/modelskill/matching.py

Lines changed: 27 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,26 @@
2222

2323
from modelskill.model.point import PointModelResult
2424

25-
from . import Quantity, __version__, model_result
25+
from . import Quantity
2626
from .comparison import Comparer, ComparerCollection
2727
from .model.dfsu import DfsuModelResult
2828
from .model.dummy import DummyModelResult
2929
from .model.grid import GridModelResult
3030
from .model.track import TrackModelResult
31-
from .obs import (
32-
Observation,
33-
PointObservation,
34-
TrackObservation,
35-
observation,
36-
)
31+
from .obs import Observation, PointObservation, TrackObservation
3732
from .timeseries import TimeSeries
3833
from .types import Period
3934

4035
TimeDeltaTypes = Union[float, int, np.timedelta64, pd.Timedelta, timedelta]
4136
IdxOrNameTypes = Optional[Union[int, str]]
4237
GeometryTypes = Optional[Literal["point", "track", "unstructured", "grid"]]
38+
MRTypes = Union[
39+
PointModelResult,
40+
GridModelResult,
41+
DfsuModelResult,
42+
TrackModelResult,
43+
DummyModelResult,
44+
]
4345
MRInputType = Union[
4446
str,
4547
Path,
@@ -52,11 +54,9 @@
5254
xr.Dataset,
5355
xr.DataArray,
5456
TimeSeries,
55-
GridModelResult,
56-
DfsuModelResult,
57-
TrackModelResult,
58-
DummyModelResult,
57+
MRTypes,
5958
]
59+
ObsTypes = Union[PointObservation, TrackObservation]
6060
ObsInputType = Union[
6161
str,
6262
Path,
@@ -65,7 +65,7 @@
6565
mikeio.Dfs0,
6666
pd.DataFrame,
6767
pd.Series,
68-
Observation,
68+
ObsTypes,
6969
]
7070

7171
T = TypeVar("T", bound="TimeSeries")
@@ -173,28 +173,24 @@ def from_matched(
173173

174174
@overload
175175
def match(
176-
obs: Observation,
177-
mod: Union[MRInputType, Sequence[MRInputType]],
176+
obs: ObsTypes,
177+
mod: MRTypes | Sequence[MRTypes],
178178
*,
179-
obs_item: Optional[IdxOrNameTypes] = None,
180-
mod_item: Optional[IdxOrNameTypes] = None,
181-
gtype: Optional[GeometryTypes] = None,
182179
max_model_gap: Optional[float] = None,
183180
spatial_method: Optional[str] = None,
181+
spatial_tolerance: float = 1e-3,
184182
obs_no_overlap: Literal["ignore", "error", "warn"] = "error",
185183
) -> Comparer: ...
186184

187185

188186
@overload
189187
def match(
190-
obs: Iterable[Observation],
191-
mod: Union[MRInputType, Sequence[MRInputType]],
188+
obs: Iterable[ObsTypes],
189+
mod: MRTypes | Sequence[MRTypes],
192190
*,
193-
obs_item: Optional[IdxOrNameTypes] = None,
194-
mod_item: Optional[IdxOrNameTypes] = None,
195-
gtype: Optional[GeometryTypes] = None,
196191
max_model_gap: Optional[float] = None,
197192
spatial_method: Optional[str] = None,
193+
spatial_tolerance: float = 1e-3,
198194
obs_no_overlap: Literal["ignore", "error", "warn"] = "error",
199195
) -> ComparerCollection: ...
200196

@@ -203,9 +199,6 @@ def match(
203199
obs,
204200
mod,
205201
*,
206-
obs_item=None,
207-
mod_item=None,
208-
gtype=None,
209202
max_model_gap=None,
210203
spatial_method: Optional[str] = None,
211204
spatial_tolerance: float = 1e-3,
@@ -222,17 +215,10 @@ def match(
222215
223216
Parameters
224217
----------
225-
obs : (str, Path, pd.DataFrame, Observation, Sequence[Observation])
218+
obs : (Observation, Sequence[Observation])
226219
Observation(s) to be compared
227-
mod : (str, Path, pd.DataFrame, ModelResult, Sequence[ModelResult])
220+
mod : (ModelResult, Sequence[ModelResult])
228221
Model result(s) to be compared
229-
obs_item : int or str, optional
230-
observation item if obs is a file/dataframe, by default None
231-
mod_item : (int, str), optional
232-
model item if mod is a file/dataframe, by default None
233-
gtype : (str, optional)
234-
Geometry type of the model result (if mod is a file/dataframe).
235-
If not specified, it will be guessed.
236222
max_model_gap : (float, optional)
237223
Maximum time gap (s) in the model result (e.g. for event-based
238224
model results), by default None
@@ -266,9 +252,6 @@ def match(
266252
return _match_single_obs(
267253
obs,
268254
mod,
269-
obs_item=obs_item,
270-
mod_item=mod_item,
271-
gtype=gtype,
272255
max_model_gap=max_model_gap,
273256
spatial_method=spatial_method,
274257
spatial_tolerance=spatial_tolerance,
@@ -303,9 +286,6 @@ def match(
303286
_match_single_obs(
304287
o,
305288
mod,
306-
obs_item=obs_item,
307-
mod_item=mod_item,
308-
gtype=gtype,
309289
max_model_gap=max_model_gap,
310290
spatial_method=spatial_method,
311291
spatial_tolerance=spatial_tolerance,
@@ -320,52 +300,42 @@ def match(
320300

321301

322302
def _match_single_obs(
323-
obs: ObsInputType,
324-
mod: Union[MRInputType, Sequence[MRInputType]],
303+
obs: ObsTypes,
304+
mod: MRTypes | Sequence[MRTypes],
325305
*,
326-
obs_item: int | str | None,
327-
mod_item: int | str | None,
328-
gtype: GeometryTypes | None,
329306
max_model_gap: float | None,
330307
spatial_method: str | None,
331308
spatial_tolerance: float,
332309
obs_no_overlap: Literal["ignore", "error", "warn"],
333310
) -> Comparer | None:
334-
# TODO passing gtype to this function is inconsistent with `match` docstring, where gtype is the geometry type of model result
335-
observation = _parse_single_obs(obs, obs_item, gtype=gtype)
336-
337311
if isinstance(mod, get_args(MRInputType)):
338312
models: list = [mod]
339313
else:
340314
models = mod # type: ignore
341315

342-
model_results = [_parse_single_model(m, item=mod_item, gtype=gtype) for m in models]
343-
names = [m.name for m in model_results]
316+
names = [m.name for m in models]
344317
if len(names) != len(set(names)):
345318
raise ValueError(f"Duplicate model names found: {names}")
346319

347320
raw_mod_data = {
348321
m.name: (
349-
m.extract(observation, spatial_method=spatial_method)
322+
m.extract(obs, spatial_method=spatial_method)
350323
if isinstance(m, (DfsuModelResult, GridModelResult, DummyModelResult))
351324
else m
352325
)
353-
for m in model_results
326+
for m in models
354327
}
355328

356329
matched_data = _match_space_time(
357-
observation=observation,
330+
observation=obs,
358331
raw_mod_data=raw_mod_data,
359332
max_model_gap=max_model_gap,
360333
obs_no_overlap=obs_no_overlap,
361334
spatial_tolerance=spatial_tolerance,
362335
)
363336
if matched_data is None:
364337
return None
365-
matched_data.attrs["weight"] = observation.weight
366-
367-
# TODO where does this line belong?
368-
matched_data.attrs["modelskill_version"] = __version__
338+
matched_data.attrs["weight"] = obs.weight
369339

370340
return Comparer(matched_data=matched_data, raw_mod_data=raw_mod_data)
371341

@@ -429,66 +399,3 @@ def mo_kind(k: str) -> bool:
429399
data = data.dropna(dim="time", subset=mo_cols)
430400

431401
return data
432-
433-
434-
def _parse_single_obs(
435-
obs: ObsInputType,
436-
obs_item: Optional[int | str],
437-
gtype: Optional[GeometryTypes],
438-
) -> PointObservation | TrackObservation:
439-
if isinstance(obs, (PointObservation, TrackObservation)):
440-
if obs_item is not None:
441-
raise ValueError(
442-
"obs_item argument not allowed if obs is an modelskill.Observation type"
443-
)
444-
return obs
445-
else:
446-
# observation factory can only handle track and point
447-
return observation(obs, item=obs_item, gtype=gtype) # type: ignore
448-
449-
450-
def _parse_single_model(
451-
mod: MRInputType,
452-
item: Optional[IdxOrNameTypes] = None,
453-
gtype: Optional[GeometryTypes] = None,
454-
) -> (
455-
PointModelResult
456-
| TrackModelResult
457-
| GridModelResult
458-
| DfsuModelResult
459-
| DummyModelResult
460-
):
461-
if isinstance(
462-
mod,
463-
(
464-
str,
465-
Path,
466-
pd.DataFrame,
467-
xr.Dataset,
468-
xr.DataArray,
469-
mikeio.Dfs0,
470-
mikeio.Dataset,
471-
mikeio.DataArray,
472-
mikeio.dfsu.Dfsu2DH,
473-
),
474-
):
475-
try:
476-
return model_result(mod, item=item, gtype=gtype)
477-
except ValueError as e:
478-
raise ValueError(
479-
f"Could not compare. Unknown model result type {type(mod)}. {str(e)}"
480-
)
481-
else:
482-
if item is not None:
483-
raise ValueError("item argument not allowed if mod is a ModelResult type")
484-
assert isinstance(
485-
mod,
486-
(
487-
PointModelResult,
488-
TrackModelResult,
489-
GridModelResult,
490-
DfsuModelResult,
491-
DummyModelResult,
492-
),
493-
)
494-
return mod

tests/test_match.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,15 @@ def test_match_dataarray(o1, o3):
111111

112112
# Using a mikeio.DataArray instead of a Dfs file makes it possible to select a subset of data
113113

114-
cc = ms.match([o1, o3], da)
114+
cc = ms.match([o1, o3], ms.DfsuModelResult(da))
115115
assert cc.n_models == 1
116116
assert cc["c2"].n_points == 41
117117

118118
da2 = mikeio.read(fn, area=[0, 2, 52, 54], time=slice("2017-10-28 00:00", None))[
119119
0
120120
] # Spatio/temporal subset
121121

122-
cc2 = ms.match([o1, o3], da2)
122+
cc2 = ms.match([o1, o3], ms.DfsuModelResult(da2))
123123
assert cc2["c2"].n_points == 19
124124

125125

@@ -231,22 +231,21 @@ def test_small_multi_model_shifted_time_match():
231231
# observation has four timesteps, but only three of them are in the Simple model and three in the NotSimple model
232232
# the number of overlapping points across all three datasets is 2, but 3 if we look at each model individually
233233

234-
with pytest.warns(UserWarning):
235-
cmp1 = ms.match(obs=obs, mod=mod)
236-
cmp1 = ms.match(obs=obs, mod=mod)
237-
assert cmp1.n_points == 3
234+
cmp1 = ms.match(obs=ms.PointObservation(obs), mod=ms.PointModelResult(mod))
235+
cmp1 = ms.match(obs=ms.PointObservation(obs), mod=ms.PointModelResult(mod))
236+
assert cmp1.n_points == 3
238237

239-
cmp2 = ms.match(obs=obs, mod=mod2)
240-
assert cmp2.n_points == 3
238+
cmp2 = ms.match(obs=ms.PointObservation(obs), mod=ms.PointModelResult(mod2))
239+
assert cmp2.n_points == 3
241240

242-
mcmp = ms.match(
243-
obs=obs,
244-
mod=[
245-
ms.PointModelResult(mod, name="foo"),
246-
ms.PointModelResult(mod2, name="bar"),
247-
],
248-
)
249-
assert mcmp.n_points == 2
241+
mcmp = ms.match(
242+
obs=ms.PointObservation(obs),
243+
mod=[
244+
ms.PointModelResult(mod, name="foo"),
245+
ms.PointModelResult(mod2, name="bar"),
246+
],
247+
)
248+
assert mcmp.n_points == 2
250249

251250

252251
def test_matched_data_single_model():
@@ -400,7 +399,7 @@ def test_save_comparercollection(o1, o3, tmp_path):
400399
fn = "tests/testdata/SW/HKZN_local_2017_DutchCoast.dfsu"
401400
da = mikeio.read(fn, time=slice("2017-10-28 00:00", None))[0]
402401

403-
cc = ms.match([o1, o3], da)
402+
cc = ms.match([o1, o3], ms.DfsuModelResult(da))
404403

405404
fn = tmp_path / "cc.msk"
406405
cc.save(fn)
@@ -427,13 +426,6 @@ def test_wind_directions():
427426
assert df.loc["obs", "c_rmse"] == pytest.approx(1.322875655532)
428427

429428

430-
def test_specifying_mod_item_not_allowed_twice(o1, mr1):
431-
# item was already specified in the construction of the DfsuModelResult
432-
433-
with pytest.raises(ValueError, match="item"):
434-
ms.match(obs=o1, mod=mr1, mod_item=1)
435-
436-
437429
def test_obs_and_mod_can_not_have_same_aux_item_names():
438430
obs_df = pd.DataFrame(
439431
{"wl": [1.0, 2.0, 3.0], "wind_speed": [1.0, 2.0, 3.0]},

0 commit comments

Comments
 (0)