Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion benchmark_utils/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,9 @@ def forecast(
for cutoff_idx, cutoff in enumerate(cutoffs):
hist = series[:cutoff] # (T_cutoff, C)
inputs.append(torch.from_numpy(hist))
covariates.append(x.covariates.slice(cutoff, prediction_length))
covariates.append(
x.covariates.slice(series_idx, cutoff, prediction_length)
)
layout.append((series_idx, cutoff_idx))

if not inputs:
Expand Down
66 changes: 43 additions & 23 deletions benchmark_utils/covariates.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Covariates payload passed to forecasting adapters.

A small dataclass so the contract is typed and IDE-discoverable. All
three fields default to empty sequences, so datasets without covariates
can just pass ``Covariates()``.
Each field is either ``None`` (that covariate kind is absent for every
series) or a sequence with one entry per series. Datasets without
covariates just pass ``Covariates()`` (all fields ``None``).
"""

from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Sequence

import numpy as np
Expand All @@ -15,38 +15,58 @@
class Covariates:
"""Per-series covariates aligned with the ``x`` sequence in ``predict``.

Each field is a sequence whose length equals ``len(x)``. Within a
series, the inner structure depends on the covariate kind — see the
forecasting predict() contract in :mod:`benchmark_utils.adapters.base`.
Each field is either ``None`` (absent for every series) or a sequence
whose length equals ``len(x)`` — one entry per series. The per-series
element shapes are listed below; see the forecasting predict() contract
in :mod:`benchmark_utils.adapters.base`.

Parameters
----------
static_covars
Shape is (channels,)
``None``, or a length-``len(x)`` sequence of arrays of shape (channels,)
hist_covars
Shape is (time, channels)
``None``, or a length-``len(x)`` sequence of arrays of shape (time, channels)
future_covars
Shape is (time, channels)
``None``, or a length-``len(x)`` sequence of arrays of shape (time, channels)
"""

static_covars: Sequence[np.ndarray] = field(default_factory=list)
hist_covars: Sequence[np.ndarray] = field(default_factory=list)
future_covars: Sequence[np.ndarray] = field(default_factory=list)
static_covars: Sequence[np.ndarray] | None = None
hist_covars: Sequence[np.ndarray] | None = None
future_covars: Sequence[np.ndarray] | None = None

def __post_init__(self):
if len(self.static_covars) != len(self.hist_covars) != len(self.future_covars):
# Every provided (non-None) field must cover the same set of series.
lengths = {
len(f)
for f in (self.static_covars, self.hist_covars, self.future_covars)
if f is not None
}
if len(lengths) > 1:
raise ValueError(
"All covariate sequences must have the same length as x"
"All provided covariate sequences must have the same length"
)

def __len__(self) -> int:
# or hist_covars or future_covars, they all have the same length
return len(self.static_covars)
"""Number of series covered (0 if no covariates are present)."""
for f in (self.static_covars, self.hist_covars, self.future_covars):
if f is not None:
return len(f)
return 0

def slice(self, series_idx: int, cutoff: int, horizon: int) -> 'Covariates':
"""Covariates for a single ``(series, cutoff)`` window.

def slice(self, cutoff: int, horizon: int) -> 'Covariates':
"""Get covariates for a single series."""
Selects series ``series_idx`` and slices its time axis: history up to
``cutoff`` and the future window ``[cutoff, cutoff + horizon)``.
"""
return Covariates(
static_covars=self.static_covars,
hist_covars=self.hist_covars[:cutoff],
future_covars=self.future_covars[cutoff:cutoff + horizon],
static_covars=None
if self.static_covars is None
else [self.static_covars[series_idx]],
hist_covars=None
if self.hist_covars is None
else [self.hist_covars[series_idx][:cutoff]],
future_covars=None
if self.future_covars is None
else [self.future_covars[series_idx][cutoff:cutoff + horizon]],
)
17 changes: 12 additions & 5 deletions benchmark_utils/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ class ForecastInput:
cutoff_indexes : sequence of sequence of int
Jagged — per-series timestep indexes at which a forecast starts.
covariates : Covariates
Static / historical / future covariates aligned with ``x``.
Defaults to empty.
Static / historical / future covariates aligned with ``x``. Each
covariate field is either ``None`` (absent for every series) or has
one entry per series (length ``len(x)``). Defaults to all-``None``.
"""

x: Sequence[np.ndarray]
Expand All @@ -37,6 +38,12 @@ class ForecastInput:
def __post_init__(self):
if len(self.x) != len(self.cutoff_indexes):
raise ValueError("x and cutoff_indexes must have the same length")
# TODO len(self.covariates) == 0 for some
# if len(self.x) != len(self.covariates):
# raise ValueError("x and covariates must have the same length")
# Covariates either cover all series (one entry each) or are absent.
# len(covariates) is the shared per-field length (0 when absent), so
# covariate-free inputs need no boilerplate.
n_cov = len(self.covariates)
if n_cov not in (0, len(self.x)):
raise ValueError(
f"covariates must cover every series: got length {n_cov}, "
f"expected 0 (absent) or len(x)={len(self.x)}"
)
1 change: 0 additions & 1 deletion datasets/enedis.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ def get_data(self):
y_train = [target[first_cut : first_cut + pred_len]]

covariates = Covariates(
static_covars=[],
hist_covars=[hist_covar],
future_covars=[future_covar],
)
Expand Down
6 changes: 3 additions & 3 deletions datasets/monash.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
positions in X_test
y_test : List[np.ndarray (n_cutoffs, H, C)]
ground-truth windows
covariates : dict {static_covars, hist_covars,
future_covars} — all empty for
Monash today
covariates : Covariates {static_covars, hist_covars,
future_covars} — all absent
(None) for Monash today
task : "forecasting"
metrics : ["mae", "mse", "rmse", "mase", "smape",
"crps", "wql", "mcis", "pinball", "skill_score_ratio"]
Expand Down
Loading