Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/modelskill/comparison/_collection_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,10 +804,9 @@ def spatial_overview(
"""
from ..plotting import spatial_overview

obs = [cmp._to_observation() for cmp in self.cc]
# TODO how to add model domain(s)

return spatial_overview(obs, ax=ax, figsize=figsize, title=title)
return spatial_overview(self.cc, ax=ax, figsize=figsize, title=title)

def temporal_coverage(
self,
Expand Down Expand Up @@ -835,11 +834,10 @@ def temporal_coverage(
"""
from ..plotting import temporal_coverage

obs = [cmp._to_observation() for cmp in self.cc]
mod = self.cc[0]._to_model()
mod = list(self.cc[0].raw_mod_data.values())

return temporal_coverage(
obs=obs,
obs=list(self.cc),
mod=mod,
limit_to_model_period=limit_to_model_period,
marker=marker,
Expand Down
32 changes: 0 additions & 32 deletions src/modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from .. import metrics as mtr
from .. import Quantity
from ..types import GeometryType
from ..obs import PointObservation, TrackObservation
from ..model import PointModelResult, TrackModelResult
from ..timeseries._timeseries import _validate_data_var_name
from ._comparer_plotter import ComparerPlotter
Expand Down Expand Up @@ -707,37 +706,6 @@ def rename(

return Comparer(matched_data=data, raw_mod_data=raw_mod_data)

def _to_observation(self) -> PointObservation | TrackObservation:
"""Convert to Observation"""
if self.gtype == "point":
df = self.data.drop_vars(["x", "y", "z"])[self._obs_str].to_dataframe()
return PointObservation(
data=df,
name=self.name,
x=self.x,
y=self.y,
z=self.z,
quantity=self.quantity,
# TODO: add attrs
)
elif self.gtype == "track":
df = self.data.drop_vars(["z"])[[self._obs_str]].to_dataframe()
return TrackObservation(
data=df,
item=0,
x_item=1,
y_item=2,
name=self.name,
quantity=self.quantity,
# TODO: add attrs
)
else:
raise NotImplementedError(f"Unknown gtype: {self.gtype}")

def _to_model(self) -> list[PointModelResult | TrackModelResult]:
mods = list(self.raw_mod_data.values())
return mods

def __add__(self, other):
warnings.warn(
"Merging comparers using + is deprecated, use .merge instead.",
Expand Down
28 changes: 21 additions & 7 deletions src/modelskill/plotting/_spatial_overview.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Optional, Iterable, Tuple, TYPE_CHECKING
from typing import Optional, Iterable, Protocol, Tuple, TYPE_CHECKING

if TYPE_CHECKING:
import matplotlib.axes
Expand All @@ -8,12 +8,26 @@

from ..model.point import PointModelResult
from ..model.track import TrackModelResult
from ..obs import Observation, PointObservation, TrackObservation
from ._misc import _get_ax


class SpatialOverviewItem(Protocol):
"""Protocol for items that can be plotted in spatial_overview."""

@property
def gtype(self) -> str: ...
@property
def x(self): ...
@property
def y(self): ...
@property
def name(self) -> str: ...
@property
def n_points(self) -> int: ...


def spatial_overview(
obs: Observation | Iterable[Observation],
obs: SpatialOverviewItem | Iterable[SpatialOverviewItem],
mod: Optional[
DfsuModelResult
| GeometryFM2D
Expand Down Expand Up @@ -82,24 +96,24 @@ def spatial_overview(
g.plot.outline(ax=ax) # type: ignore

for o in obs:
if isinstance(o, PointObservation):
if o.gtype == "point":
ax.scatter(x=o.x, y=o.y, marker="x")
elif isinstance(o, TrackObservation):
elif o.gtype == "track":
if o.n_points < 10000:
ax.scatter(x=o.x, y=o.y, marker=".")
else:
print(f"{o.name}: Too many points to plot")
# TODO: group by lonlat bin or sample randomly
else:
raise ValueError(
f"Could not show observation {o}. Only PointObservation and TrackObservation supported."
f"Could not show {o}. Only point and track geometry types supported."
)

xlim = ax.get_xlim()
offset_x = 0.02 * (xlim[1] - xlim[0])

for o in obs:
if isinstance(o, PointObservation):
if o.gtype == "point":
# TODO adjust xlim to accomodate text
ax.annotate(o.name, (o.x + offset_x, o.y)) # type: ignore

Expand Down