diff --git a/benchmark_utils/base_solver.py b/benchmark_utils/base_solver.py index cd82ebd..04547e2 100644 --- a/benchmark_utils/base_solver.py +++ b/benchmark_utils/base_solver.py @@ -251,7 +251,9 @@ def forecast( for cutoff_idx, cutoff in enumerate(cutoffs): hist = series[:cutoff] # (T_cutoff, C) inputs.append(torch.from_numpy(hist)) - covariates.append(x.covariates.slice(cutoff, prediction_length)) + covariates.append( + x.covariates.slice(series_idx, cutoff, prediction_length) + ) layout.append((series_idx, cutoff_idx)) if not inputs: diff --git a/benchmark_utils/covariates.py b/benchmark_utils/covariates.py index 15db310..f3e1913 100644 --- a/benchmark_utils/covariates.py +++ b/benchmark_utils/covariates.py @@ -1,11 +1,11 @@ """Covariates payload passed to forecasting adapters. -A small dataclass so the contract is typed and IDE-discoverable. All -three fields default to empty sequences, so datasets without covariates -can just pass ``Covariates()``. +Each field is either ``None`` (that covariate kind is absent for every +series) or a sequence with one entry per series. Datasets without +covariates just pass ``Covariates()`` (all fields ``None``). """ -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Sequence import numpy as np @@ -15,38 +15,58 @@ class Covariates: """Per-series covariates aligned with the ``x`` sequence in ``predict``. - Each field is a sequence whose length equals ``len(x)``. Within a - series, the inner structure depends on the covariate kind — see the - forecasting predict() contract in :mod:`benchmark_utils.adapters.base`. + Each field is either ``None`` (absent for every series) or a sequence + whose length equals ``len(x)`` — one entry per series. The per-series + element shapes are listed below; see the forecasting predict() contract + in :mod:`benchmark_utils.adapters.base`. Parameters ---------- static_covars - Shape is (channels,) + ``None``, or a length-``len(x)`` sequence of arrays of shape (channels,) hist_covars - Shape is (time, channels) + ``None``, or a length-``len(x)`` sequence of arrays of shape (time, channels) future_covars - Shape is (time, channels) + ``None``, or a length-``len(x)`` sequence of arrays of shape (time, channels) """ - static_covars: Sequence[np.ndarray] = field(default_factory=list) - hist_covars: Sequence[np.ndarray] = field(default_factory=list) - future_covars: Sequence[np.ndarray] = field(default_factory=list) + static_covars: Sequence[np.ndarray] | None = None + hist_covars: Sequence[np.ndarray] | None = None + future_covars: Sequence[np.ndarray] | None = None def __post_init__(self): - if len(self.static_covars) != len(self.hist_covars) != len(self.future_covars): + # Every provided (non-None) field must cover the same set of series. + lengths = { + len(f) + for f in (self.static_covars, self.hist_covars, self.future_covars) + if f is not None + } + if len(lengths) > 1: raise ValueError( - "All covariate sequences must have the same length as x" + "All provided covariate sequences must have the same length" ) - + def __len__(self) -> int: - # or hist_covars or future_covars, they all have the same length - return len(self.static_covars) + """Number of series covered (0 if no covariates are present).""" + for f in (self.static_covars, self.hist_covars, self.future_covars): + if f is not None: + return len(f) + return 0 + + def slice(self, series_idx: int, cutoff: int, horizon: int) -> 'Covariates': + """Covariates for a single ``(series, cutoff)`` window. - def slice(self, cutoff: int, horizon: int) -> 'Covariates': - """Get covariates for a single series.""" + Selects series ``series_idx`` and slices its time axis: history up to + ``cutoff`` and the future window ``[cutoff, cutoff + horizon)``. + """ return Covariates( - static_covars=self.static_covars, - hist_covars=self.hist_covars[:cutoff], - future_covars=self.future_covars[cutoff:cutoff + horizon], + static_covars=None + if self.static_covars is None + else [self.static_covars[series_idx]], + hist_covars=None + if self.hist_covars is None + else [self.hist_covars[series_idx][:cutoff]], + future_covars=None + if self.future_covars is None + else [self.future_covars[series_idx][cutoff:cutoff + horizon]], ) diff --git a/benchmark_utils/inputs.py b/benchmark_utils/inputs.py index 6266207..c0298ec 100644 --- a/benchmark_utils/inputs.py +++ b/benchmark_utils/inputs.py @@ -26,8 +26,9 @@ class ForecastInput: cutoff_indexes : sequence of sequence of int Jagged — per-series timestep indexes at which a forecast starts. covariates : Covariates - Static / historical / future covariates aligned with ``x``. - Defaults to empty. + Static / historical / future covariates aligned with ``x``. Each + covariate field is either ``None`` (absent for every series) or has + one entry per series (length ``len(x)``). Defaults to all-``None``. """ x: Sequence[np.ndarray] @@ -37,6 +38,12 @@ class ForecastInput: def __post_init__(self): if len(self.x) != len(self.cutoff_indexes): raise ValueError("x and cutoff_indexes must have the same length") - # TODO len(self.covariates) == 0 for some - # if len(self.x) != len(self.covariates): - # raise ValueError("x and covariates must have the same length") + # Covariates either cover all series (one entry each) or are absent. + # len(covariates) is the shared per-field length (0 when absent), so + # covariate-free inputs need no boilerplate. + n_cov = len(self.covariates) + if n_cov not in (0, len(self.x)): + raise ValueError( + f"covariates must cover every series: got length {n_cov}, " + f"expected 0 (absent) or len(x)={len(self.x)}" + ) diff --git a/datasets/enedis.py b/datasets/enedis.py index 65695b3..92ef2fd 100644 --- a/datasets/enedis.py +++ b/datasets/enedis.py @@ -158,7 +158,6 @@ def get_data(self): y_train = [target[first_cut : first_cut + pred_len]] covariates = Covariates( - static_covars=[], hist_covars=[hist_covar], future_covars=[future_covar], ) diff --git a/datasets/monash.py b/datasets/monash.py index a607f64..d9e6401 100644 --- a/datasets/monash.py +++ b/datasets/monash.py @@ -20,9 +20,9 @@ positions in X_test y_test : List[np.ndarray (n_cutoffs, H, C)] ground-truth windows -covariates : dict {static_covars, hist_covars, - future_covars} — all empty for - Monash today +covariates : Covariates {static_covars, hist_covars, + future_covars} — all absent + (None) for Monash today task : "forecasting" metrics : ["mae", "mse", "rmse", "mase", "smape", "crps", "wql", "mcis", "pinball", "skill_score_ratio"]