diff --git a/src/modelskill/comparison/_collection.py b/src/modelskill/comparison/_collection.py index 884abe81c..fb91e1969 100644 --- a/src/modelskill/comparison/_collection.py +++ b/src/modelskill/comparison/_collection.py @@ -1,4 +1,5 @@ from __future__ import annotations +from collections.abc import Collection, Sequence from copy import deepcopy import os from pathlib import Path @@ -12,7 +13,6 @@ Union, Optional, Mapping, - Iterable, overload, Hashable, Tuple, @@ -88,7 +88,7 @@ class ComparerCollection(Mapping): plotter = ComparerCollectionPlotter - def __init__(self, comparers: Iterable[Comparer]) -> None: + def __init__(self, comparers: Collection[Comparer]) -> None: self._comparers: Dict[str, Comparer] = {} names = [c.name for c in comparers] @@ -215,13 +215,13 @@ def rename(self, mapping: Dict[str, str]) -> "ComparerCollection": return ComparerCollection(cmps) @overload - def __getitem__(self, x: slice | Iterable[Hashable]) -> ComparerCollection: ... + def __getitem__(self, x: slice | Collection[Hashable]) -> ComparerCollection: ... @overload def __getitem__(self, x: int | Hashable) -> Comparer: ... def __getitem__( - self, x: int | Hashable | slice | Iterable[Hashable] + self, x: int | Hashable | slice | Collection[Hashable] ) -> Comparer | ComparerCollection: if isinstance(x, str): return self._comparers[x] @@ -234,7 +234,7 @@ def __getitem__( name = _get_name(x, self.obs_names) return self._comparers[name] - if isinstance(x, Iterable): + if isinstance(x, Collection): cmps = [self[i] for i in x] return ComparerCollection(cmps) @@ -422,8 +422,8 @@ def query(self, query: str) -> "ComparerCollection": def skill( self, - by: str | Iterable[str] | None = None, - metrics: Iterable[str] | Iterable[Callable] | str | Callable | None = None, + by: str | Collection[str] | None = None, + metrics: Sequence[str] | Sequence[Callable] | str | Callable | None = None, observed: bool = False, ) -> SkillTable: """Aggregated skill assessment of model(s) @@ -492,14 +492,14 @@ def skill( df = cc._to_long_dataframe(attrs_keys=attrs_keys, observed=observed) res = _groupby_df(df, by=agg_cols, metrics=pmetrics) - mtr_cols = [m.__name__ for m in pmetrics] # type: ignore + mtr_cols = [m.__name__ for m in pmetrics] res = res.dropna(subset=mtr_cols, how="all") # TODO: ok to remove empty? res = self._append_xy_to_res(res, cc) - res = cc._add_as_col_if_not_in_index(df, skilldf=res) # type: ignore + res = cc._add_as_col_if_not_in_index(df, skilldf=res) return SkillTable(res) def _to_long_dataframe( - self, attrs_keys: Iterable[str] | None = None, observed: bool = False + self, attrs_keys: Collection[str] | None = None, observed: bool = False ) -> pd.DataFrame: """Return a copy of the data as a long-format pandas DataFrame (for groupby operations)""" frames = [] @@ -570,8 +570,8 @@ def gridded_skill( self, bins: int = 5, binsize: float | None = None, - by: str | Iterable[str] | None = None, - metrics: Iterable[str] | Iterable[Callable] | str | Callable | None = None, + by: str | Collection[str] | None = None, + metrics: Sequence[str] | Sequence[Callable] | str | Callable | None = None, n_min: Optional[int] = None, **kwargs: Any, ) -> SkillGrid: diff --git a/src/modelskill/comparison/_comparer_plotter.py b/src/modelskill/comparison/_comparer_plotter.py index 5467eafc2..8144de67b 100644 --- a/src/modelskill/comparison/_comparer_plotter.py +++ b/src/modelskill/comparison/_comparer_plotter.py @@ -16,7 +16,7 @@ import matplotlib.axes from ._comparison import Comparer -import numpy as np # type: ignore +import numpy as np from .. import metrics as mtr from ..utils import _get_idx @@ -115,7 +115,7 @@ def timeseries( return ax elif backend == "plotly": # pragma: no cover - import plotly.graph_objects as go # type: ignore + import plotly.graph_objects as go mod_scatter_list = [] for j in range(cmp.n_models): @@ -746,7 +746,14 @@ def taylor( df = df.rename(columns={"_std_obs": "obs_std", "_std_mod": "std"}) pts = [ - TaylorPoint(name=r.model, obs_std=r.obs_std, std=r.std, cc=r.cc, marker=marker, marker_size=marker_size) + TaylorPoint( + name=r.model, + obs_std=r.obs_std, + std=r.std, + cc=r.cc, + marker=marker, + marker_size=marker_size, + ) for r in df.itertuples() ] diff --git a/src/modelskill/comparison/_comparison.py b/src/modelskill/comparison/_comparison.py index f8fedef75..5b768cad3 100644 --- a/src/modelskill/comparison/_comparison.py +++ b/src/modelskill/comparison/_comparison.py @@ -1,4 +1,5 @@ from __future__ import annotations +from collections.abc import Collection, Sequence from dataclasses import dataclass from pathlib import Path from typing import ( @@ -10,8 +11,6 @@ Mapping, Optional, Union, - Iterable, - Sequence, TYPE_CHECKING, ) import warnings @@ -128,7 +127,7 @@ def _is_model(da: xr.DataArray) -> bool: return str(da.attrs["kind"]) == "model" -def _validate_metrics(metrics: Iterable[Any]) -> None: +def _validate_metrics(metrics: Collection[Any]) -> None: for m in metrics: if isinstance(m, str): if not mtr.is_valid_metric(m): @@ -195,8 +194,8 @@ def all(self) -> Sequence[str]: def parse( items: Sequence[str], obs_item: str | int | None = None, - mod_items: Optional[Iterable[str | int]] = None, - aux_items: Optional[Iterable[str | int]] = None, + mod_items: Collection[str | int] | None = None, + aux_items: Collection[str | int] | None = None, x_item: str | int | None = None, y_item: str | int | None = None, ) -> ItemSelection: @@ -284,18 +283,18 @@ def _area_is_polygon(area: Any) -> bool: def _inside_polygon(polygon: Any, xy: np.ndarray) -> np.ndarray: - import matplotlib.path as mp # type: ignore + import matplotlib.path as mp if polygon.ndim == 1: polygon = np.column_stack((polygon[0::2], polygon[1::2])) - return mp.Path(polygon).contains_points(xy) # type: ignore + return mp.Path(polygon).contains_points(xy) def _matched_data_to_xarray( df: pd.DataFrame, obs_item: int | str | None = None, - mod_items: Optional[Iterable[str | int]] = None, - aux_items: Optional[Iterable[str | int]] = None, + mod_items: Collection[str | int] | None = None, + aux_items: Collection[str | int] | None = None, name: Optional[str] = None, x: Optional[float] = None, y: Optional[float] = None, @@ -464,10 +463,10 @@ def __init__( @staticmethod def from_matched_data( data: xr.Dataset | pd.DataFrame, - raw_mod_data: Optional[Dict[str, PointModelResult | TrackModelResult]] = None, + raw_mod_data: Optional[dict[str, PointModelResult | TrackModelResult]] = None, obs_item: str | int | None = None, - mod_items: Optional[Iterable[str | int]] = None, - aux_items: Optional[Iterable[str | int]] = None, + mod_items: Optional[Collection[str | int]] = None, + aux_items: Optional[Collection[str | int]] = None, name: Optional[str] = None, weight: float = 1.0, x: Optional[float] = None, @@ -477,7 +476,41 @@ def from_matched_data( y_item: str | int | None = None, quantity: Optional[Quantity] = None, ) -> "Comparer": - """Initialize from compared data""" + """Create a Comparer from data that is already matched (aligned). + + Parameters + ---------- + data : [pd.DataFrame, xr.Dataset] + DataFrame (or xarray.Dataset) + raw_mod_data : dict of modelskill.PointModelResult, optional + Raw model data. If None, observation and modeldata must be provided. + obs_item : [str, int], optional + Name or index of observation item, by default first item + mod_items : Collection of [str, int], optional + Names or indicies of model items, if None all remaining columns are model items, by default None + aux_items : Collection of [str, int], optional + Names or indicies of auxiliary items, by default None + name : str, optional + Name of the comparer, by default None (will be set to obs_item) + x : float, optional + x-coordinate of observation, by default None + y : float, optional + y-coordinate of observation, by default None + z : float, optional + z-coordinate of observation, by default None + x_item: [str, int], optional, + Name of x item, only relevant for track data + y_item: [str, int], optional + Name of y item, only relevant for track data + quantity : Quantity, optional + Quantity of the observation and model results, by default Quantity(name="Undefined", unit="Undefined") + + Returns + ------- + Comparer + A Comparer object with matched observation and model data + """ + if not isinstance(data, xr.Dataset): # TODO: handle raw_mod_data by accessing data.attrs["kind"] and only remove nan after data = _matched_data_to_xarray( @@ -818,7 +851,7 @@ def sel( raw_mod_data = {m: raw_mod_data[m] for m in mod_names} if (start is not None) or (end is not None): # TODO: can this be done without to_index? (simplify) - d = d.sel(time=d.time.to_index().to_frame().loc[start:end].index) # type: ignore + d = d.sel(time=d.time.to_index().to_frame().loc[start:end].index) # type: ignore[misc] # Note: if user asks for a specific time, we also filter raw raw_mod_data = { @@ -892,7 +925,7 @@ def query(self, query: str) -> "Comparer": return Comparer.from_matched_data(d, self.raw_mod_data) def _to_long_dataframe( - self, attrs_keys: Iterable[str] | None = None + self, attrs_keys: Collection[str] | None = None ) -> pd.DataFrame: """Return a copy of the data as a long-format pandas DataFrame (for groupby operations)""" @@ -927,8 +960,8 @@ def _to_long_dataframe( def skill( self, - by: str | Iterable[str] | None = None, - metrics: Iterable[str] | Iterable[Callable] | str | Callable | None = None, + by: str | Collection[str] | None = None, + metrics: Sequence[str] | Sequence[Callable] | str | Callable | None = None, ) -> SkillTable: """Skill assessment of model(s) @@ -1051,8 +1084,8 @@ def gridded_skill( self, bins: int = 5, binsize: float | None = None, - by: str | Iterable[str] | None = None, - metrics: Iterable[str] | Iterable[Callable] | str | Callable | None = None, + by: str | Collection[str] | None = None, + metrics: Sequence[str] | Sequence[Callable] | str | Callable | None = None, n_min: int | None = None, **kwargs: Any, ): @@ -1176,12 +1209,12 @@ def remove_bias( for j in range(cmp.n_models): mod_name = cmp.mod_names[j] mod_ts = cmp.raw_mod_data[mod_name] - with xr.set_options(keep_attrs=True): # type: ignore + with xr.set_options(keep_attrs=True): mod_ts.data[mod_name].values = mod_ts.values - bias[j] cmp.data[mod_name].values = cmp.data[mod_name].values - bias[j] elif correct == "Observation": # what if multiple models? - with xr.set_options(keep_attrs=True): # type: ignore + with xr.set_options(keep_attrs=True): cmp.data[cmp._obs_str].values = cmp.data[cmp._obs_str].values + bias else: raise ValueError( diff --git a/src/modelskill/comparison/_utils.py b/src/modelskill/comparison/_utils.py index f8d399d1a..762959a70 100644 --- a/src/modelskill/comparison/_utils.py +++ b/src/modelskill/comparison/_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations -from typing import Callable, Optional, Iterable, List, Tuple, Union +from collections.abc import Collection +from typing import Callable, Optional, List, Tuple, Union from datetime import datetime import numpy as np import pandas as pd @@ -139,7 +140,7 @@ def _add_dt_to_df(df: pd.DataFrame, by: List[str]) -> Tuple[pd.DataFrame, List[s def _parse_groupby( - by: str | Iterable[str] | None, *, n_mod: int, n_qnt: int + by: str | Collection[str] | None, *, n_mod: int, n_qnt: int ) -> List[str | pd.Grouper]: if by is None: cols: List[str | pd.Grouper] @@ -153,7 +154,7 @@ def _parse_groupby( if isinstance(by, str): cols = [by] - elif isinstance(by, Iterable): + elif isinstance(by, Collection): cols = list(by) res = [] diff --git a/src/modelskill/matching.py b/src/modelskill/matching.py index b182f0ad7..fe7aa4d2b 100644 --- a/src/modelskill/matching.py +++ b/src/modelskill/matching.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import ( Collection, - Iterable, Literal, Mapping, Optional, @@ -75,8 +74,8 @@ def from_matched( data: Union[str, Path, pd.DataFrame, mikeio.Dfs0, mikeio.Dataset], *, obs_item: str | int | None = 0, - mod_items: Optional[Iterable[str | int]] = None, - aux_items: Optional[Iterable[str | int]] = None, + mod_items: Optional[Collection[str | int]] = None, + aux_items: Optional[Collection[str | int]] = None, quantity: Optional[Quantity] = None, name: Optional[str] = None, weight: float = 1.0, @@ -95,9 +94,9 @@ def from_matched( with columns obs_item, mod_items, aux_items obs_item : [str, int], optional Name or index of observation item, by default first item - mod_items : Iterable[str, int], optional + mod_items : Collection of [str, int], optional Names or indicies of model items, if None all remaining columns are model items, by default None - aux_items : Iterable[str, int], optional + aux_items : Collection of [str, int], optional Names or indicies of auxiliary items, by default None quantity : Quantity, optional Quantity of the observation and model results, by default Quantity(name="Undefined", unit="Undefined") @@ -185,7 +184,7 @@ def match( @overload def match( - obs: Iterable[ObsTypes], + obs: Sequence[ObsTypes], mod: MRTypes | Sequence[MRTypes], *, max_model_gap: Optional[float] = None, @@ -340,7 +339,7 @@ def _match_single_obs( return Comparer(matched_data=matched_data, raw_mod_data=raw_mod_data) -def _get_global_start_end(idxs: Iterable[pd.DatetimeIndex]) -> Period: +def _get_global_start_end(idxs: Collection[pd.DatetimeIndex]) -> Period: assert all([len(x) > 0 for x in idxs]) starts = [x[0] for x in idxs] diff --git a/src/modelskill/metrics.py b/src/modelskill/metrics.py index 30e4b02bc..95a86d30c 100644 --- a/src/modelskill/metrics.py +++ b/src/modelskill/metrics.py @@ -5,10 +5,10 @@ import sys import warnings +from collections.abc import Sequence from typing import ( Any, Callable, - Iterable, List, Literal, Optional, @@ -1164,7 +1164,7 @@ def get_metric(metric: Union[str, Callable]) -> Callable: def _parse_metric( - metric: str | Iterable[str] | Callable | Iterable[Callable] | None, + metric: str | Sequence[str] | Callable | Sequence[Callable] | None, *, directional: bool = False, ) -> List[Callable]: @@ -1179,7 +1179,7 @@ def _parse_metric( metrics: list = [metric] elif callable(metric): metrics = [metric] - elif isinstance(metric, Iterable): + elif isinstance(metric, Sequence): metrics = list(metric) parsed_metrics = [] diff --git a/src/modelskill/model/dfsu.py b/src/modelskill/model/dfsu.py index 511e829d6..c81b32d49 100644 --- a/src/modelskill/model/dfsu.py +++ b/src/modelskill/model/dfsu.py @@ -109,7 +109,7 @@ def time(self) -> pd.DatetimeIndex: return pd.DatetimeIndex(self.data.time) def _in_domain(self, x: float, y: float) -> bool: - return self.data.geometry.contains([x, y]) # type: ignore + return self.data.geometry.contains([x, y]) def extract( self, observation: Observation, spatial_method: Optional[str] = None @@ -194,7 +194,7 @@ def _extract_point( if isinstance(self.data, mikeio.Dataset): ds_model = self.data.isel(element=elemids) else: # Dfsu - ds_model = self.data.read(elements=elemids, items=self.sel_items.all) # type: ignore + ds_model = self.data.read(elements=elemids, items=self.sel_items.all) else: if z is not None: raise NotImplementedError( diff --git a/src/modelskill/plotting/_scatter.py b/src/modelskill/plotting/_scatter.py index de2cd700a..d92aa4837 100644 --- a/src/modelskill/plotting/_scatter.py +++ b/src/modelskill/plotting/_scatter.py @@ -222,7 +222,7 @@ def scatter( df = pd.DataFrame({"obs": x, "model": y}) cmp = from_matched(df) metrics = None if skill_table is True else skill_table - skill = cmp.skill(metrics=metrics) + skill = cmp.skill(metrics=metrics) # type: ignore[arg-type] skill_scores = skill.to_dict("records")[0] return PLOTTING_BACKENDS[backend]( diff --git a/src/modelskill/plotting/_spatial_overview.py b/src/modelskill/plotting/_spatial_overview.py index 156af04c1..2492aced0 100644 --- a/src/modelskill/plotting/_spatial_overview.py +++ b/src/modelskill/plotting/_spatial_overview.py @@ -1,5 +1,6 @@ from __future__ import annotations -from typing import Optional, Iterable, Tuple, TYPE_CHECKING +from collections.abc import Sequence +from typing import Optional, Tuple, TYPE_CHECKING if TYPE_CHECKING: import matplotlib.axes @@ -13,12 +14,12 @@ def spatial_overview( - obs: Observation | Iterable[Observation], + obs: Observation | Sequence[Observation], mod: Optional[ DfsuModelResult | GeometryFM2D - | Iterable[DfsuModelResult] - | Iterable[GeometryFM2D] + | Sequence[DfsuModelResult] + | Sequence[GeometryFM2D] ] = None, ax=None, figsize: Optional[Tuple] = None, @@ -61,8 +62,8 @@ def spatial_overview( ms.plotting.spatial_overview([o1, o2], mr) ``` """ - obs = [] if obs is None else list(obs) if isinstance(obs, Iterable) else [obs] # type: ignore - mods = [] if mod is None else list(mod) if isinstance(mod, Iterable) else [mod] # type: ignore + obs = [] if obs is None else list(obs) if isinstance(obs, Sequence) else [obs] + mods = [] if mod is None else list(mod) if isinstance(mod, Sequence) else [mod] ax = _get_ax(ax=ax, figsize=figsize) @@ -79,7 +80,7 @@ def spatial_overview( g = m # TODO this is not supported for all model types - g.plot.outline(ax=ax) # type: ignore + g.plot.outline(ax=ax) for o in obs: if isinstance(o, PointObservation): @@ -101,7 +102,7 @@ def spatial_overview( for o in obs: if isinstance(o, PointObservation): # TODO adjust xlim to accomodate text - ax.annotate(o.name, (o.x + offset_x, o.y)) # type: ignore + ax.annotate(o.name, (o.x + offset_x, o.y)) if not title: title = "Spatial coverage" diff --git a/src/modelskill/plotting/_wind_rose.py b/src/modelskill/plotting/_wind_rose.py index b10e018bf..01f82d676 100644 --- a/src/modelskill/plotting/_wind_rose.py +++ b/src/modelskill/plotting/_wind_rose.py @@ -116,10 +116,10 @@ def _dirhist2d( mask = data[:, 0] >= vmin calm = len(data[~mask]) / len(data) n = len(data) - counts, _, _ = np.histogram2d( # type: ignore + counts, _, _ = np.histogram2d( data[mask][:, 0], data[mask][:, 1], - bins=[ui, thetai], # type: ignore + bins=[ui, thetai], ) density = counts / n return DirectionalHistogram( @@ -552,7 +552,7 @@ def _add_legend_to_ax( -0.06, 0.1, 0.8, - ) # type: ignore + ) loc = "lower left" else: bbox_to_anchor = (-0.13, -0.06, 0.1, 0.8) diff --git a/src/modelskill/skill.py b/src/modelskill/skill.py index d16bda8dd..0c4010c08 100644 --- a/src/modelskill/skill.py +++ b/src/modelskill/skill.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Iterable, Collection, overload, Hashable, TYPE_CHECKING +from collections.abc import Collection, Sequence +from typing import Any, Iterable, overload, Hashable, TYPE_CHECKING import numpy as np import pandas as pd @@ -14,7 +15,7 @@ # TODO remove ? -def _validate_multi_index(index, min_levels=2, max_levels=2): # type: ignore +def _validate_multi_index(index, min_levels=2, max_levels=2): errors = [] if isinstance(index, pd.MultiIndex): if len(index.levels) < min_levels: @@ -220,7 +221,7 @@ def grid( s = self.skillarray ser = s._ser - errors = _validate_multi_index(ser.index) # type: ignore + errors = _validate_multi_index(ser.index) if len(errors) > 0: warnings.warn("plot_grid: " + "\n".join(errors)) # TODO raise error? @@ -282,11 +283,11 @@ def grid( class DeprecatedSkillPlotter: - def __init__(self, skilltable): # type: ignore + def __init__(self, skilltable): self.skilltable = skilltable @staticmethod - def _deprecated_warning(method, field): # type: ignore + def _deprecated_warning(method, field): warnings.warn( f"Selecting metric in plot functions like modelskill.skill().plot.{method}({field}) is deprecated and will be removed in a future version. Use modelskill.skill()['{field}'].plot.{method}() instead.", FutureWarning, @@ -297,20 +298,20 @@ def __call__(self, *args: Any, **kwds: Any) -> Any: "It is not possible to call plot directly on SkillTable! Select metric first (which gives a plotable SkillArray)" ) - def line(self, field, **kwargs): # type: ignore - self._deprecated_warning("line", field) # type: ignore + def line(self, field, **kwargs): + self._deprecated_warning("line", field) return self.skilltable[field].plot.line(**kwargs) - def bar(self, field, **kwargs): # type: ignore - self._deprecated_warning("bar", field) # type: ignore + def bar(self, field, **kwargs): + self._deprecated_warning("bar", field) return self.skilltable[field].plot.bar(**kwargs) - def barh(self, field, **kwargs): # type: ignore - self._deprecated_warning("barh", field) # type: ignore + def barh(self, field, **kwargs): + self._deprecated_warning("barh", field) return self.skilltable[field].plot.barh(**kwargs) - def grid(self, field, **kwargs): # type: ignore - self._deprecated_warning("grid", field) # type: ignore + def grid(self, field, **kwargs): + self._deprecated_warning("grid", field) return self.skilltable[field].plot.grid(**kwargs) @@ -429,7 +430,7 @@ def __init__(self, data: pd.DataFrame): data if isinstance(data, pd.DataFrame) else data.to_dataframe() ) # TODO remove in v1.1 - self.plot = DeprecatedSkillPlotter(self) # type: ignore + self.plot = DeprecatedSkillPlotter(self) # TODO: remove? @property @@ -507,10 +508,10 @@ def _repr_html_(self) -> Any: def __getitem__(self, key: Hashable | int) -> SkillArray: ... @overload - def __getitem__(self, key: Iterable[Hashable]) -> SkillTable: ... + def __getitem__(self, key: Collection[Hashable]) -> SkillTable: ... def __getitem__( - self, key: Hashable | Iterable[Hashable] + self, key: Hashable | Collection[Hashable] ) -> SkillArray | SkillTable: if isinstance(key, int): key = list(self.data.columns)[key] @@ -545,14 +546,14 @@ def __getattr__(self, item: str, *args, **kwargs) -> Any: # ) @property - def iloc(self, *args, **kwargs): # type: ignore + def iloc(self, *args, **kwargs): return self.data.iloc(*args, **kwargs) @property - def loc(self, *args, **kwargs): # type: ignore + def loc(self, *args, **kwargs): return self.data.loc(*args, **kwargs) - def sort_index(self, *args, **kwargs) -> SkillTable: # type: ignore + def sort_index(self, *args, **kwargs) -> SkillTable: """Sort by index (level) e.g. sorting by observation Wrapping pd.DataFrame.sort_index() @@ -570,7 +571,7 @@ def sort_index(self, *args, **kwargs) -> SkillTable: # type: ignore """ return self.__class__(self.data.sort_index(*args, **kwargs)) - def sort_values(self, *args, **kwargs) -> SkillTable: # type: ignore + def sort_values(self, *args, **kwargs) -> SkillTable: """Sort by values e.g. sorting by rmse values Wrapping pd.DataFrame.sort_values() @@ -589,7 +590,7 @@ def sort_values(self, *args, **kwargs) -> SkillTable: # type: ignore """ return self.__class__(self.data.sort_values(*args, **kwargs)) - def swaplevel(self, *args, **kwargs) -> SkillTable: # type: ignore + def swaplevel(self, *args, **kwargs) -> SkillTable: """Swap the levels of the MultiIndex e.g. swapping 'model' and 'observation' Wrapping pd.DataFrame.swaplevel() @@ -765,7 +766,7 @@ def round(self, decimals: int = 3) -> SkillTable: def style( self, decimals: int = 3, - metrics: Iterable[str] | None = None, + metrics: Sequence[str] | None = None, cmap: str = "OrRd", show_best: bool = True, **kwargs: Any, @@ -844,7 +845,7 @@ def style( if len(bg_cols) > 0: sdf = sdf.background_gradient(subset=scols, cmap=cmap) - cmap_r = self._reverse_colormap(cmap) # type: ignore + cmap_r = self._reverse_colormap(cmap) sdf = sdf.background_gradient(subset=lcols, cmap=cmap_r) if show_best: @@ -856,7 +857,7 @@ def style( return sdf - def _reverse_colormap(self, cmap): # type: ignore + def _reverse_colormap(self, cmap): cmap_r = cmap if isinstance(cmap, str): if cmap[-2:] == "_r": @@ -901,14 +902,14 @@ def _style_max(self, s: pd.Series) -> list[str]: # TODO: remove plot_* methods in v1.1; warnings are not needed # as the refering method is also deprecated - def plot_line(self, **kwargs): # type: ignore - return self.plot.line(**kwargs) # type: ignore + def plot_line(self, **kwargs): + return self.plot.line(**kwargs) - def plot_bar(self, **kwargs): # type: ignore - return self.plot.bar(**kwargs) # type: ignore + def plot_bar(self, **kwargs): + return self.plot.bar(**kwargs) - def plot_barh(self, **kwargs): # type: ignore - return self.plot.barh(**kwargs) # type: ignore + def plot_barh(self, **kwargs): + return self.plot.barh(**kwargs) - def plot_grid(self, **kwargs): # type: ignore - return self.plot.grid(**kwargs) # type: ignore + def plot_grid(self, **kwargs): + return self.plot.grid(**kwargs) diff --git a/src/modelskill/skill_grid.py b/src/modelskill/skill_grid.py index 8af6cf07c..4c527a04a 100644 --- a/src/modelskill/skill_grid.py +++ b/src/modelskill/skill_grid.py @@ -1,5 +1,6 @@ from __future__ import annotations -from typing import Any, Iterable, overload, Hashable, TYPE_CHECKING +from collections.abc import Collection +from typing import Any, overload, Hashable, TYPE_CHECKING import xarray as xr if TYPE_CHECKING: @@ -13,12 +14,12 @@ class SkillGridMixin: @property def x(self) -> xr.DataArray: """x-coordinate values""" - return self.data.x # type: ignore + return self.data.x @property def y(self) -> xr.DataArray: """y-coordinate values""" - return self.data.y # type: ignore + return self.data.y @property def coords(self) -> Any: @@ -141,10 +142,10 @@ def __repr__(self) -> str: def __getitem__(self, key: Hashable) -> SkillGridArray: ... @overload - def __getitem__(self, key: Iterable[Hashable]) -> SkillGrid: ... + def __getitem__(self, key: Collection[Hashable]) -> SkillGrid: ... def __getitem__( - self, key: Hashable | Iterable[Hashable] + self, key: Hashable | Collection[Hashable] ) -> SkillGridArray | SkillGrid: result = self.data[key] if isinstance(result, xr.DataArray): diff --git a/src/modelskill/timeseries/_plotter.py b/src/modelskill/timeseries/_plotter.py index e66374a48..391b168c0 100644 --- a/src/modelskill/timeseries/_plotter.py +++ b/src/modelskill/timeseries/_plotter.py @@ -97,7 +97,7 @@ def timeseries(self): Wraps plotly.express.line() function. """ - import plotly.express as px # type: ignore + import plotly.express as px fig = px.line( self._ts._values_as_series, color_discrete_sequence=[self._ts._color] @@ -116,7 +116,7 @@ def hist(self, bins=100, **kwargs): **kwargs other keyword arguments to df.hist() """ - import plotly.express as px # type: ignore + import plotly.express as px fig = px.histogram( self._ts._values_as_series,