From ceef4f92c5fb373ca427f4eea4b3adf8c49e3132 Mon Sep 17 00:00:00 2001 From: ABarneche Date: Thu, 28 May 2026 16:12:37 +0200 Subject: [PATCH 01/15] adding mitdb to the list of datasets --- .gitignore | 2 + benchmark_utils/download.py | 19 +++++ datasets/mitdb.py | 146 ++++++++++++++++++++++++++++++++++++ 3 files changed, 167 insertions(+) create mode 100644 datasets/mitdb.py diff --git a/.gitignore b/.gitignore index 78df841..d72c924 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ /outputs/ /data/ **/build/ +physionet* +venv* # Cache directory **/__cache__ diff --git a/benchmark_utils/download.py b/benchmark_utils/download.py index becb2c9..0b5a9ce 100644 --- a/benchmark_utils/download.py +++ b/benchmark_utils/download.py @@ -33,6 +33,25 @@ } +def fetch_mitdb() -> Path: + """Return the local directory holding MIT-BIH Arrhythmia Database files. + + Downloads the database via ``wfdb.dl_database`` on first call; subsequent + calls are cache hits if the header files are already present. + + Returns + ------- + Path directory containing ``.hea / .dat / .atr`` files + """ + import wfdb + _MITDB_DIR = Path(__file__).parent.parent / "data" / "mitdb" + + _MITDB_DIR.mkdir(parents=True, exist_ok=True) + if not (_MITDB_DIR / "100.hea").exists(): + wfdb.dl_database("mitdb", dl_dir=str(_MITDB_DIR)) + return _MITDB_DIR + + def fetch_tsb_uad(name: str) -> Path: """Return the local directory holding TSB-UAD's ``.out`` files for *name*. diff --git a/datasets/mitdb.py b/datasets/mitdb.py new file mode 100644 index 0000000..46cc90f --- /dev/null +++ b/datasets/mitdb.py @@ -0,0 +1,146 @@ +"""MIT-BIH Arrhythmia Database — event detection (1-D segmentation). + +Each record is a 2-channel ECG sampled at 360 Hz. Beat annotations are +converted to per-sample integer class labels using the 5-class AAMI grouping: + + 0 background (between beats) + 1 N — Normal / bundle-branch-block / paced + 2 S — Supraventricular ectopic + 3 V — Ventricular ectopic + 4 F — Fusion + 5 Q — Unknown / pacemaker artefact + +Each annotated R-peak is expanded by ±beat_window samples; samples outside +any window are labelled 0 (background). + +Data contract output +-------------------- +X_train : List[np.ndarray (T_i, 2)] training portions (C == 2) +y_train : List[np.ndarray (T_i,)] int labels 0–5 +X_test : List[np.ndarray (T_j, 2)] test portions +y_test : List[np.ndarray (T_j,)] int labels 0–5 +task : "event_detection" +metrics : ["segment_f1", "iou_segment"] +""" + +import numpy as np +from benchopt import BaseDataset + + +# AAMI beat-type grouping (MIT-BIH annotation symbol → class index) +BEAT_CLASS = { + # N group + "N": 1, "L": 1, "R": 1, "e": 1, "j": 1, + # S group + "A": 2, "a": 2, "J": 2, "S": 2, + # V group + "V": 3, "E": 3, + # F group + "F": 4, + # Q group + "P": 5, "f": 5, "u": 5, +} + +# All 48 standard MIT-BIH record IDs +MITDB_RECORDS = [ + "100", "101", "102", "103", "104", "105", "106", "107", + "108", "109", "111", "112", "113", "114", "115", "116", + "117", "118", "119", "121", "122", "123", "124", + "200", "201", "202", "203", "205", "207", "208", "209", + "210", "212", "213", "214", "215", "217", "219", "220", + "221", "222", "223", "228", "230", "231", "232", "233", "234", +] + + +def _load_record(record_id, data_dir): + """Load one WFDB record and return (signal, labels) as numpy arrays. + + Parameters + ---------- + record_id : str e.g. "100" + data_dir : str or Path local directory holding .hea / .dat / .atr files + + Returns + ------- + signal : np.ndarray (T, 2) float32 + labels : np.ndarray (T,) int32 per-sample class 0–5 + """ + raise NotImplementedError + + +def _make_label_array(n_samples, ann_samples, ann_symbols, beat_window): + """Convert beat annotations to a per-sample label array. + + Parameters + ---------- + n_samples : int + ann_samples : np.ndarray (A,) int sample indices of each annotation + ann_symbols : list of str annotation symbols (len A) + beat_window : int half-width of label window in samples + + Returns + ------- + labels : np.ndarray (n_samples,) int32 + """ + raise NotImplementedError + + +class Dataset(BaseDataset): + """MIT-BIH Arrhythmia Database for event detection. + + Parameters + ---------- + record_ids : list of str or "all" + Which records to include. Defaults to the full 48-record set. + debug : bool + If True, use only the first 2 records and truncate to 5 000 samples. + train_ratio : float + Fraction of each record used as training data. + beat_window : int + Half-width (in samples) of the label window around each R-peak. + Default 36 ≈ ±100 ms at 360 Hz (covers the QRS complex). + """ + + name = "MITDB" + + requirements = ["wfdb"] + + parameters = { + "record_ids": ["all"], + "debug": [False], + "train_ratio": [0.7], + "beat_window": [36], + } + + def get_data(self): + from benchmark_utils.download import fetch_mitdb + + data_dir = fetch_mitdb() + + record_ids = MITDB_RECORDS if self.record_ids == "all" else self.record_ids + if self.debug: + record_ids = record_ids[:2] + + X_train, y_train, X_test, y_test = [], [], [], [] + for rid in record_ids: + signal, labels = _load_record(rid, data_dir) + + if self.debug: + signal = signal[:5000] + labels = labels[:5000] + + split = max(1, int(len(signal) * self.train_ratio)) + + X_train.append(signal[:split]) + y_train.append(labels[:split]) + X_test.append(signal[split:]) + y_test.append(labels[split:]) + + return dict( + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + task="event_detection", + metrics=["segment_f1", "iou_segment"], + ) From 2850f8e96dbe95dba178c31188006e69d87f58bc Mon Sep 17 00:00:00 2001 From: ABarneche Date: Thu, 28 May 2026 17:15:25 +0200 Subject: [PATCH 02/15] [WIP] Adding event detection task (MITDB Arythmia detection) --- README.rst | 1 + benchmark_utils/metrics.py | 89 +++++++++++++++++++++++++++- datasets/mitdb.py | 117 +++++++++++++++++++++++++------------ objective.py | 20 ++++++- solvers/naive.py | 19 +++++- 5 files changed, 206 insertions(+), 40 deletions(-) diff --git a/README.rst b/README.rst index 8b2aa78..2aae122 100644 --- a/README.rst +++ b/README.rst @@ -9,6 +9,7 @@ The goal is to provide a benchmark that evaluate the models on: - Classification - Forecasting - Anomaly Detection +- Event Detection With diverse modalities (univariate, multivariate, EEG, etc.) and varying sequence lengths. diff --git a/benchmark_utils/metrics.py b/benchmark_utils/metrics.py index 954d5b1..ae62e3b 100644 --- a/benchmark_utils/metrics.py +++ b/benchmark_utils/metrics.py @@ -172,6 +172,88 @@ def _point_adjust(y_true, y_pred): return y_pred_adj +# --------------------------------------------------------------------------- +# Event detection +# --------------------------------------------------------------------------- + +def _iou_1d(s1, w1, s2, w2): + s1, w1, s2, w2 = float(s1), float(w1), float(s2), float(w2) + inter = max(0.0, min(s1 + w1, s2 + w2) - max(s1, s2)) + union = w1 + w2 - inter + return inter / union if union > 0.0 else 0.0 + + +def _ap_from_tp_fp(tp, fp, n_gt): + """Area under the precision-recall step function.""" + if n_gt == 0: + return float("nan") + tp_cum = np.cumsum(tp) + fp_cum = np.cumsum(fp) + recall = tp_cum / n_gt + precision = tp_cum / (tp_cum + fp_cum) + recall = np.concatenate([[0.0], recall]) + precision = np.concatenate([[1.0], precision]) + return float(np.sum((recall[1:] - recall[:-1]) * precision[1:])) + + +def map_iou(y_true, y_pred, iou_threshold=0.5): + """Mean Average Precision at a 1-D IoU threshold for event detection. + + Parameters + ---------- + y_true : list of np.ndarray (N_gt, 2+K) + Ground-truth events per series. Cols: [start_norm, width_norm, *one_hot]. + y_pred : list of np.ndarray (N_pred, 2+K) + Predicted events per series. Cols: [start_norm, width_norm, *class_scores]. + Score for class k is y_pred[i, 2+k]; confidence = per-class score. + iou_threshold : float + Minimum IoU to count a prediction as a true positive (default 0.5). + """ + if not y_true: + return float("nan") + + n_classes = y_true[0].shape[1] - 2 + aps = [] + + for k in range(n_classes): + # Collect GT boxes for class k, grouped by series index + gt_by_series = {} + n_gt = 0 + for i, gt in enumerate(y_true): + boxes = [(row[0], row[1]) for row in gt + if len(gt) > 0 and np.argmax(row[2:]) == k] + gt_by_series[i] = boxes + n_gt += len(boxes) + + # Collect all predictions for class k: (series_idx, start, width, score) + preds = [] + for i, pred in enumerate(y_pred): + for row in pred: + preds.append((i, row[0], row[1], float(row[2 + k]))) + preds.sort(key=lambda x: -x[3]) + + matched = {i: [False] * len(gt_by_series[i]) for i in gt_by_series} + tp = np.zeros(len(preds)) + fp = np.zeros(len(preds)) + + for j, (i, s, w, _) in enumerate(preds): + best_iou, best_gi = 0.0, -1 + for gi, (gs, gw) in enumerate(gt_by_series.get(i, [])): + iou = _iou_1d(s, w, gs, gw) + if iou > best_iou: + best_iou, best_gi = iou, gi + if best_iou >= iou_threshold and best_gi >= 0 and not matched[i][best_gi]: + tp[j] = 1.0 + matched[i][best_gi] = True + else: + fp[j] = 1.0 + + aps.append(_ap_from_tp_fp(tp, fp, n_gt)) + + valid = [ap for ap in aps if not np.isnan(ap)] + return float(np.mean(valid)) if valid else float("nan") + + # --------------------------------------------------------------------------- # Registry: maps metric name → function # --------------------------------------------------------------------------- @@ -196,4 +278,9 @@ def _point_adjust(y_true, y_pred): "f1_pa": f1_pa, } -ALL_METRICS = {**FORECASTING_METRICS, **CLASSIFICATION_METRICS, **AD_METRICS} +EVENT_METRICS = { + "map_iou": map_iou, +} + +ALL_METRICS = {**FORECASTING_METRICS, **CLASSIFICATION_METRICS, **AD_METRICS, + **EVENT_METRICS} diff --git a/datasets/mitdb.py b/datasets/mitdb.py index 46cc90f..f8a06f5 100644 --- a/datasets/mitdb.py +++ b/datasets/mitdb.py @@ -1,30 +1,31 @@ -"""MIT-BIH Arrhythmia Database — event detection (1-D segmentation). +"""MIT-BIH Arrhythmia Database — 1-D event detection. Each record is a 2-channel ECG sampled at 360 Hz. Beat annotations are -converted to per-sample integer class labels using the 5-class AAMI grouping: +converted to an object-detection style target using the 5-class AAMI grouping: - 0 background (between beats) - 1 N — Normal / bundle-branch-block / paced - 2 S — Supraventricular ectopic - 3 V — Ventricular ectopic - 4 F — Fusion - 5 Q — Unknown / pacemaker artefact - -Each annotated R-peak is expanded by ±beat_window samples; samples outside -any window are labelled 0 (background). + class 0 N — Normal / bundle-branch-block / paced + class 1 S — Supraventricular ectopic + class 2 V — Ventricular ectopic + class 3 F — Fusion + class 4 Q — Unknown / pacemaker artefact Data contract output -------------------- -X_train : List[np.ndarray (T_i, 2)] training portions (C == 2) -y_train : List[np.ndarray (T_i,)] int labels 0–5 -X_test : List[np.ndarray (T_j, 2)] test portions -y_test : List[np.ndarray (T_j,)] int labels 0–5 +X_train : List[np.ndarray (T_i, 2)] training portions (C == 2) +y_train : List[np.ndarray (N_i, 2+K)] one row per beat event: + col 0 start (normalised, 0–1) + col 1 width (normalised, 0–1) + cols 2… one-hot class vector (K) +X_test : List[np.ndarray (T_j, 2)] test portions +y_test : List[np.ndarray (N_j, 2+K)] same format task : "event_detection" -metrics : ["segment_f1", "iou_segment"] +metrics : ["map_iou"] +extra : n_classes (int) K above """ import numpy as np from benchopt import BaseDataset +from benchmark_utils.download import fetch_mitdb # AAMI beat-type grouping (MIT-BIH annotation symbol → class index) @@ -62,31 +63,61 @@ def _load_record(record_id, data_dir): Returns ------- - signal : np.ndarray (T, 2) float32 - labels : np.ndarray (T,) int32 per-sample class 0–5 + signal : np.ndarray (T, 2) float32 + ann_samples: np.ndarray (A,) int32 R-peak sample indices + ann_symbols: list of str length A annotation symbols """ - raise NotImplementedError + import wfdb + + path = str(data_dir / record_id) + record = wfdb.rdrecord(path) + ann = wfdb.rdann(path, "atr") + signal = record.p_signal.astype(np.float32) + return signal, ann.sample, ann.symbol -def _make_label_array(n_samples, ann_samples, ann_symbols, beat_window): - """Convert beat annotations to a per-sample label array. + +def _annotations_to_events(n_samples, ann_samples, ann_symbols, beat_window, + n_classes): + """Convert beat annotations to an object-detection target array. Parameters ---------- - n_samples : int + n_samples : int total length of the series ann_samples : np.ndarray (A,) int sample indices of each annotation ann_symbols : list of str annotation symbols (len A) - beat_window : int half-width of label window in samples + beat_window : int half-width of each event in samples + n_classes : int K — number of AAMI classes Returns ------- - labels : np.ndarray (n_samples,) int32 + events : np.ndarray (N, 2+K) float32 + Each row: [start_norm, width_norm, *one_hot_class] + Only beats whose symbol appears in BEAT_CLASS are included. """ - raise NotImplementedError + rows = [] + for sample, symbol in zip(ann_samples, ann_symbols): + aami_class = BEAT_CLASS.get(symbol) + if aami_class is None: + continue + # Collapse to single class when n_classes == 1 + class_idx = 0 if n_classes == 1 else aami_class - 1 + if class_idx >= n_classes: + continue + + start = max(0, sample - beat_window) + end = min(n_samples, sample + beat_window) + one_hot = np.zeros(n_classes, dtype=np.float32) + one_hot[class_idx] = 1.0 + rows.append([start / n_samples, (end - start) / n_samples, *one_hot]) + + if not rows: + return np.zeros((0, 2 + n_classes), dtype=np.float32) + return np.array(rows, dtype=np.float32) class Dataset(BaseDataset): - """MIT-BIH Arrhythmia Database for event detection. + """MIT-BIH Arrhythmia Database for 1-D event detection. Parameters ---------- @@ -97,8 +128,12 @@ class Dataset(BaseDataset): train_ratio : float Fraction of each record used as training data. beat_window : int - Half-width (in samples) of the label window around each R-peak. + Half-width (in samples) of each event box around the R-peak. Default 36 ≈ ±100 ms at 360 Hz (covers the QRS complex). + n_classes : int + K — number of AAMI beat classes to distinguish (1–5). + Classes are ordered N, S, V, F, Q; setting n_classes=1 collapses all + annotated beats into a single "beat" class. """ name = "MITDB" @@ -110,11 +145,10 @@ class Dataset(BaseDataset): "debug": [False], "train_ratio": [0.7], "beat_window": [36], + "n_classes": [5], } def get_data(self): - from benchmark_utils.download import fetch_mitdb - data_dir = fetch_mitdb() record_ids = MITDB_RECORDS if self.record_ids == "all" else self.record_ids @@ -123,18 +157,28 @@ def get_data(self): X_train, y_train, X_test, y_test = [], [], [], [] for rid in record_ids: - signal, labels = _load_record(rid, data_dir) + signal, ann_samples, ann_symbols = _load_record(rid, data_dir) if self.debug: + mask = ann_samples < 5000 + ann_samples = ann_samples[mask] + ann_symbols = [s for s, m in zip(ann_symbols, mask) if m] signal = signal[:5000] - labels = labels[:5000] split = max(1, int(len(signal) * self.train_ratio)) - X_train.append(signal[:split]) - y_train.append(labels[:split]) - X_test.append(signal[split:]) - y_test.append(labels[split:]) + for seg_signal, start, end, Xl, yl in [ + (signal[:split], 0, split, X_train, y_train), + (signal[split:], split, len(signal), X_test, y_test), + ]: + seg_ann = ann_samples[(ann_samples >= start) & (ann_samples < end)] - start + seg_sym = [s for s, idx in zip(ann_symbols, ann_samples) + if start <= idx < end] + Xl.append(seg_signal) + yl.append(_annotations_to_events( + len(seg_signal), seg_ann, seg_sym, + self.beat_window, self.n_classes, + )) return dict( X_train=X_train, @@ -142,5 +186,6 @@ def get_data(self): X_test=X_test, y_test=y_test, task="event_detection", - metrics=["segment_f1", "iou_segment"], + metrics=["map_iou"], + n_classes=self.n_classes, ) diff --git a/objective.py b/objective.py index 61332cf..e51b4ab 100644 --- a/objective.py +++ b/objective.py @@ -13,7 +13,7 @@ X_test : List[np.ndarray (T_j, C)] test contexts / series y_test : array-like task-specific (see below) task : str one of {"forecasting", "classification", - "anomaly_detection"} + "anomaly_detection", "event_detection"} metrics : List[str] names from benchmark_utils.metrics.ALL_METRICS Task-specific shapes @@ -26,6 +26,9 @@ extra n_classes (int) anomaly_detection y_train None y_test List[(T_j,)] int point-level binary labels +event_detection y_train List[(N_i, 2+K)] float object-detection boxes + y_test List[(N_j, 2+K)] float object-detection boxes + extra n_classes (int) Solver contract --------------- @@ -92,6 +95,8 @@ def evaluate_result(self, model): return self._eval_classification(model) elif self.task == "anomaly_detection": return self._eval_anomaly_detection(model) + elif self.task == "event_detection": + return self._eval_event_detection(model) else: raise ValueError(f"Unknown task: {self.task!r}") @@ -128,6 +133,17 @@ def _eval_classification(self, model): result[name] = ALL_METRICS[name](y_true, y_pred) return result + # --- event detection ----------------------------------------------- + + def _eval_event_detection(self, model): + # model.predict returns (N, 2+K) float array per series + preds = [np.asarray(model.predict(x)) for x in self.X_test] + + result = {} + for name in self.metrics: + result[name] = ALL_METRICS[name](self.y_test, preds) + return result + # --- anomaly detection --------------------------------------------- def _eval_anomaly_detection(self, model): @@ -162,5 +178,7 @@ def predict(self, x): return 0 elif self._task == "anomaly_detection": return np.zeros(x.shape[0]) + elif self._task == "event_detection": + return np.zeros((0, 2 + self._meta.get("n_classes", 1))) return {"model": _ConstantAdapter(self.task, self.meta, self.X_test)} diff --git a/solvers/naive.py b/solvers/naive.py index be8cdcd..448e855 100644 --- a/solvers/naive.py +++ b/solvers/naive.py @@ -1,8 +1,9 @@ -"""Naive baseline solver — works for all three tasks. +"""Naive baseline solver — works for all four tasks. Forecasting : seasonal naive (repeat last season) Classification : most-frequent-class in training set Anomaly detection: constant zero scores (everything is normal) +Event detection : predict no events (empty box array) This solver has no model dependencies and should always pass ``benchopt test``. It also serves as a reference for the expected solver structure. @@ -57,6 +58,16 @@ def predict(self, x: np.ndarray) -> np.ndarray: return np.zeros(x.shape[0], dtype=np.float32) +class _NoEventPredictor(BaseTSFMAdapter): + """Predict no events — returns an empty (0, 2+K) box array.""" + + def __init__(self, n_classes): + self._n_classes = n_classes + + def predict(self, x: np.ndarray) -> np.ndarray: + return np.zeros((0, 2 + self._n_classes), dtype=np.float32) + + # --------------------------------------------------------------------------- # Solver # --------------------------------------------------------------------------- @@ -79,7 +90,8 @@ class Solver(BaseSolver): "seasonality": [1], } - SUPPORTED_TASKS = {"forecasting", "classification", "anomaly_detection"} + SUPPORTED_TASKS = {"forecasting", "classification", "anomaly_detection", + "event_detection"} def skip(self, task, **kwargs): if task not in self.SUPPORTED_TASKS: @@ -104,5 +116,8 @@ def run(self, _): elif self.task == "anomaly_detection": self._adapter = _ConstantScorer() + elif self.task == "event_detection": + self._adapter = _NoEventPredictor(self.meta.get("n_classes", 1)) + def get_result(self): return {"model": self._adapter} From 94b3ed597fc9ef40059be1739c848dc344eca54a Mon Sep 17 00:00:00 2001 From: Tianjun Hou Date: Fri, 29 May 2026 11:18:56 +0200 Subject: [PATCH 03/15] ENH add Chronos event-detection solver with EventHead adapter - Add EventHead (Transformer-decoder, 10 learned queries) and ChronosEventAdapter in benchmark_utils/adapters/event_detection.py - Add Chronos-EventDetection solver in solvers/chronos_event.py: frozen Chronos encoder + trainable head, with model_path param for offline/local model loading - Add event_detection task branch to objective.py - Add event_span_iou and event_class_f1 metrics to metrics.py - Fix chronos.py: defer ChronosPipeline import to set_objective to avoid import-time crash when chronos-forecasting is not installed Co-authored-by: Cursor --- benchmark_utils/adapters/__init__.py | 3 + benchmark_utils/adapters/event_detection.py | 276 +++++++++++++++++ benchmark_utils/metrics.py | 155 +++++++++- objective.py | 43 ++- solvers/chronos.py | 8 +- solvers/chronos_event.py | 326 ++++++++++++++++++++ 6 files changed, 789 insertions(+), 22 deletions(-) create mode 100644 benchmark_utils/adapters/event_detection.py create mode 100644 solvers/chronos_event.py diff --git a/benchmark_utils/adapters/__init__.py b/benchmark_utils/adapters/__init__.py index c5d9f32..6039249 100644 --- a/benchmark_utils/adapters/__init__.py +++ b/benchmark_utils/adapters/__init__.py @@ -9,6 +9,7 @@ ) from .linear_probe import LinearProbeAdapter from .forecast_residual import ForecastResidualAdapter +from .event_detection import EventHead, ChronosEventAdapter __all__ = [ "BaseTSFMAdapter", @@ -20,4 +21,6 @@ "Encoder", "LinearProbeAdapter", "ForecastResidualAdapter", + "EventHead", + "ChronosEventAdapter", ] diff --git a/benchmark_utils/adapters/event_detection.py b/benchmark_utils/adapters/event_detection.py new file mode 100644 index 0000000..ca97af0 --- /dev/null +++ b/benchmark_utils/adapters/event_detection.py @@ -0,0 +1,276 @@ +"""Event-detection adapter for a frozen Chronos encoder + trainable head. + +Architecture +------------ +x (T, C) + -> Chronos.embed per channel (T_tok, D) x C + -> mean-pool across channels (T_tok, D) [memory key/value] + -> EventHead Transformer-decoder [10 learned queries] + -> pos_head : Linear(D,2) -> sigmoid (start, length) in [0,1] + -> cls_head : Linear(D,k) -> logits k binary class logits + +Output shape per series: (N=10, 2+k) + +Channel handling +---------------- +Chronos is strictly univariate. We embed every channel independently and then +**mean-pool the per-channel (T_tok, D) tensors**. Alternatives are possible: + + * concat along D: increases D by C-fold, requires retraining head per C + * attention pool: add a learned aggregation layer (more parameters) + +Mean-pool is simple, parameter-free, and works for any C at inference time. + +Target format +------------- +y per series: (N=10, 2+k) float32, all-zero rows = empty / no-event slots. + col 0 : start (normalised to [0,1] over T=512) + col 1 : length (normalised to [0,1] over T=512) + col 2..2+k : binary multi-class one-hot columns (sum >= 1 for real events) + +Loss +---- +For each of the 10 query slots: + * has_event mask = (y_cls.sum(-1) > 0) shape (B, N) + * position loss : smooth_l1 on (start, length), applied only where has_event + * class loss : BCEWithLogitsLoss on all k columns, applied to all slots + (empty slots drive logits toward 0, which is the no-event baseline) + +Combined: loss = pos_loss + lambda_cls * cls_loss (lambda_cls=1.0 by default) +""" + +import numpy as np +import torch +import torch.nn as nn + +from .base import BaseTSFMAdapter + + +# --------------------------------------------------------------------------- +# EventHead +# --------------------------------------------------------------------------- + +class EventHead(nn.Module): + """Transformer-decoder that turns Chronos embeddings into span predictions. + + Parameters + ---------- + d_model : int + Embedding dimension coming out of the Chronos encoder. + n_classes : int + Number of binary event-class columns k. + num_queries : int + Fixed number of event slots N (default 10). + num_decoder_layers : int + Depth of the Transformer decoder (default 2). + nhead : int + Number of attention heads (default 8 when d_model >= 512). + dim_feedforward : int + FFN inner dimension (default 4 * d_model). + dropout : float + Dropout rate in the decoder (default 0.1). + """ + + def __init__( + self, + d_model: int, + n_classes: int, + num_queries: int = 10, + num_decoder_layers: int = 2, + nhead: int = 8, + dim_feedforward: int | None = None, + dropout: float = 0.1, + ): + super().__init__() + self.d_model = d_model + self.n_classes = n_classes + self.num_queries = num_queries + + if dim_feedforward is None: + dim_feedforward = 4 * d_model + + # Ensure nhead divides d_model cleanly + while d_model % nhead != 0 and nhead > 1: + nhead //= 2 + + decoder_layer = nn.TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + batch_first=True, + ) + self.decoder = nn.TransformerDecoder( + decoder_layer, + num_layers=num_decoder_layers, + ) + + # N learnable query embeddings — one per event slot + self.query_embed = nn.Embedding(num_queries, d_model) + + # Output heads + self.pos_head = nn.Sequential( + nn.Linear(d_model, d_model // 2), + nn.ReLU(), + nn.Linear(d_model // 2, 2), + ) + self.cls_head = nn.Sequential( + nn.Linear(d_model, d_model // 2), + nn.ReLU(), + nn.Linear(d_model // 2, n_classes), + ) + + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + nn.init.normal_(self.query_embed.weight, std=0.02) + + def forward(self, memory: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass. + + Parameters + ---------- + memory : (B, T_tok, D) — Chronos encoder embeddings (mean-pooled over C) + + Returns + ------- + pos_logits : (B, N, 2) — raw (pre-sigmoid) span predictions + cls_logits : (B, N, k) — raw (pre-sigmoid) class logits + """ + B = memory.size(0) + # Expand queries across batch + queries = self.query_embed.weight.unsqueeze(0).expand(B, -1, -1) # (B, N, D) + + decoded = self.decoder(tgt=queries, memory=memory) # (B, N, D) + + pos_logits = self.pos_head(decoded) # (B, N, 2) + cls_logits = self.cls_head(decoded) # (B, N, k) + return pos_logits, cls_logits + + def compute_loss( + self, + pos_logits: torch.Tensor, + cls_logits: torch.Tensor, + y: torch.Tensor, + lambda_cls: float = 1.0, + ) -> torch.Tensor: + """Compute combined position + classification loss. + + Parameters + ---------- + pos_logits : (B, N, 2) — raw span predictions (sigmoid applied inside) + cls_logits : (B, N, k) — raw class logits + y : (B, N, 2+k) — ground-truth targets, float32 + lambda_cls : float — weight for the class loss term + + Returns + ------- + scalar loss tensor + """ + y_pos = y[..., :2] # (B, N, 2) start, length + y_cls = y[..., 2:] # (B, N, k) binary class targets + + # Mask: only penalise position loss on slots that have a real event + has_event = (y_cls.sum(dim=-1) > 0).float() # (B, N) + + pos_pred = torch.sigmoid(pos_logits) # (B, N, 2) in [0,1] + pos_loss_per = nn.functional.smooth_l1_loss( + pos_pred, y_pos, reduction="none" + ).mean(dim=-1) # (B, N) + pos_loss = (pos_loss_per * has_event).sum() / (has_event.sum() + 1e-6) + + cls_loss = nn.functional.binary_cross_entropy_with_logits( + cls_logits, y_cls, reduction="mean" + ) + + return pos_loss + lambda_cls * cls_loss + + +# --------------------------------------------------------------------------- +# ChronosEventAdapter +# --------------------------------------------------------------------------- + +class ChronosEventAdapter(BaseTSFMAdapter): + """Fitted adapter: frozen Chronos encoder + trained EventHead. + + Parameters + ---------- + pipeline : ChronosPipeline + A loaded (and frozen) Chronos pipeline instance. + head : EventHead + A trained EventHead instance on the target device. + device : str + torch device string, e.g. "cuda" or "cpu". + n_classes : int + Number of binary event-class columns k. + T : int + Input series length (used only for documentation; not enforced here). + """ + + def __init__(self, pipeline, head: EventHead, device: str, + n_classes: int, T: int = 512): + self.pipeline = pipeline + self.head = head + self.device = device + self.n_classes = n_classes + self.T = T + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _embed_series(self, x: np.ndarray) -> torch.Tensor: + """Embed a single multivariate series via Chronos, mean-pool over C. + + Parameters + ---------- + x : (T, C) float array + + Returns + ------- + Tensor of shape (T_tok, D) on CPU, float32 + """ + import torch as _torch + C = x.shape[1] + channel_embs = [] + for c in range(C): + ctx = _torch.tensor(x[:, c], dtype=_torch.float32) + # pipeline.embed returns (1, T_tok, D), tokenizer_state + emb, _ = self.pipeline.embed(ctx.unsqueeze(0)) # (1, T_tok, D) + channel_embs.append(emb.squeeze(0).float().cpu()) # (T_tok, D) + # Stack and mean-pool over channels + stacked = _torch.stack(channel_embs, dim=0) # (C, T_tok, D) + return stacked.mean(dim=0) # (T_tok, D) + + # ------------------------------------------------------------------ + # BaseTSFMAdapter interface + # ------------------------------------------------------------------ + + def predict(self, x: np.ndarray) -> np.ndarray: + """Run inference on a single multivariate series. + + Parameters + ---------- + x : (T, C) float array + + Returns + ------- + spans : (N=10, 2+k) float32 array + Columns 0-1 : start and length in [0,1] + Columns 2.. : binary class probabilities in [0,1] + """ + self.head.eval() + memory = self._embed_series(x) # (T_tok, D) cpu + memory = memory.unsqueeze(0).to(self.device) # (1, T_tok, D) + + with torch.no_grad(): + pos_logits, cls_logits = self.head(memory) # (1,N,2), (1,N,k) + + pos = torch.sigmoid(pos_logits[0]).cpu().numpy() # (N, 2) + cls = torch.sigmoid(cls_logits[0]).cpu().numpy() # (N, k) + return np.concatenate([pos, cls], axis=-1).astype(np.float32) # (N, 2+k) diff --git a/benchmark_utils/metrics.py b/benchmark_utils/metrics.py index ae62e3b..2eb010b 100644 --- a/benchmark_utils/metrics.py +++ b/benchmark_utils/metrics.py @@ -254,6 +254,151 @@ def map_iou(y_true, y_pred, iou_threshold=0.5): return float(np.mean(valid)) if valid else float("nan") +# _span_iou is an alias kept consistent with _iou_1d above +def _span_iou(start_a, len_a, start_b, len_b): + """Intersection-over-Union for two [start, start+length) intervals.""" + return _iou_1d(start_a, len_a, start_b, len_b) + + +def event_span_iou(y_true, y_pred, iou_threshold=0.5): + """Mean span IoU for event detection with greedy matching. + + Computes precision/recall/F1 of event spans across all series using greedy + IoU matching. A predicted span is a true positive if its IoU with an + unmatched ground-truth span exceeds ``iou_threshold``. + + Parameters + ---------- + y_true : List[np.ndarray (N=10, 2+k)] + Ground-truth padded event targets. All-zero rows = empty slots. + y_pred : List[np.ndarray (N=10, 2+k)] + Predicted outputs. Positions (cols 0-1) in [0,1]; class probs in [0,1]. + iou_threshold : float + Minimum IoU to count as a correct span detection (default 0.5). + + Returns + ------- + float — span F1 score averaged over all series + """ + f1_scores = [] + for gt, pr in zip(y_true, y_pred): + gt = np.asarray(gt) + pr = np.asarray(pr) + + gt_mask = gt[:, 2:].sum(axis=1) > 0 + pr_mask = pr[:, 2:].max(axis=1) > 0.5 + + gt_spans = gt[gt_mask, :2] + pr_spans = pr[pr_mask, :2] + + G = gt_spans.shape[0] + P = pr_spans.shape[0] + + if G == 0 and P == 0: + f1_scores.append(1.0) + continue + if G == 0 or P == 0: + f1_scores.append(0.0) + continue + + matched_gt = set() + tp = 0 + for pi in range(P): + best_iou = 0.0 + best_gi = -1 + for gi in range(G): + if gi in matched_gt: + continue + iou = _span_iou( + pr_spans[pi, 0], pr_spans[pi, 1], + gt_spans[gi, 0], gt_spans[gi, 1], + ) + if iou > best_iou: + best_iou = iou + best_gi = gi + if best_iou >= iou_threshold and best_gi >= 0: + matched_gt.add(best_gi) + tp += 1 + + precision = tp / P if P > 0 else 0.0 + recall = tp / G if G > 0 else 0.0 + f1 = (2 * precision * recall / (precision + recall) + if precision + recall > 0 else 0.0) + f1_scores.append(f1) + + return float(np.mean(f1_scores)) + + +def event_class_f1(y_true, y_pred, iou_threshold=0.5): + """Micro-F1 over binary class columns on IoU-matched event slots. + + For each predicted span that is IoU-matched to a ground-truth span, we + threshold class probabilities at 0.5 and compute micro-F1 over the k + binary class columns. Unmatched ground-truth spans count as false negatives + for all their active classes. + + Parameters + ---------- + y_true : List[np.ndarray (N=10, 2+k)] + y_pred : List[np.ndarray (N=10, 2+k)] + iou_threshold : float + + Returns + ------- + float — micro-F1 over class columns + """ + tp_total = fp_total = fn_total = 0 + + for gt, pr in zip(y_true, y_pred): + gt = np.asarray(gt) + pr = np.asarray(pr) + + gt_mask = gt[:, 2:].sum(axis=1) > 0 + pr_mask = pr[:, 2:].max(axis=1) > 0.5 + + gt_spans = gt[gt_mask] + pr_spans = pr[pr_mask] + + G = gt_spans.shape[0] + P = pr_spans.shape[0] + + matched_gt = {} + matched_pr = set() + for gi in range(G): + best_iou = 0.0 + best_pi = -1 + for pi in range(P): + if pi in matched_pr: + continue + iou = _span_iou( + pr_spans[pi, 0], pr_spans[pi, 1], + gt_spans[gi, 0], gt_spans[gi, 1], + ) + if iou > best_iou: + best_iou = iou + best_pi = pi + if best_iou >= iou_threshold and best_pi >= 0: + matched_gt[gi] = best_pi + matched_pr.add(best_pi) + + for gi, pi in matched_gt.items(): + gt_cls = (gt_spans[gi, 2:] > 0.5).astype(int) + pr_cls = (pr_spans[pi, 2:] > 0.5).astype(int) + tp_total += int((gt_cls & pr_cls).sum()) + fp_total += int(((1 - gt_cls) & pr_cls).sum()) + fn_total += int((gt_cls & (1 - pr_cls)).sum()) + + for gi in range(G): + if gi not in matched_gt: + fn_total += int((gt_spans[gi, 2:] > 0.5).sum()) + + precision = tp_total / (tp_total + fp_total) if (tp_total + fp_total) > 0 else 0.0 + recall = tp_total / (tp_total + fn_total) if (tp_total + fn_total) > 0 else 0.0 + if precision + recall > 0: + return float(2 * precision * recall / (precision + recall)) + return 0.0 + + # --------------------------------------------------------------------------- # Registry: maps metric name → function # --------------------------------------------------------------------------- @@ -280,7 +425,13 @@ def map_iou(y_true, y_pred, iou_threshold=0.5): EVENT_METRICS = { "map_iou": map_iou, + "event_span_iou": event_span_iou, + "event_class_f1": event_class_f1, } -ALL_METRICS = {**FORECASTING_METRICS, **CLASSIFICATION_METRICS, **AD_METRICS, - **EVENT_METRICS} +ALL_METRICS = { + **FORECASTING_METRICS, + **CLASSIFICATION_METRICS, + **AD_METRICS, + **EVENT_METRICS, +} diff --git a/objective.py b/objective.py index 5e3a4c6..0f0474a 100644 --- a/objective.py +++ b/objective.py @@ -1,8 +1,8 @@ """ Unified objective for the TSFM benchmark. -Supports three tasks — forecasting, classification, anomaly detection — -dispatched via the ``task`` field provided by each dataset. +Supports four tasks — forecasting, classification, anomaly detection, and +event detection — dispatched via the ``task`` field provided by each dataset. Data contract ------------- @@ -35,9 +35,14 @@ extra n_classes (int) anomaly_detection y_train None y_test List[(T_j,)] int point-level binary labels -event_detection y_train List[(N_i, 2+K)] float object-detection boxes - y_test List[(N_j, 2+K)] float object-detection boxes - extra n_classes (int) +event_detection y_train List[(N, 2+K)] float object-detection boxes + y_test List[(N, 2+K)] float object-detection boxes + extra n_classes (int), T (int, default 512) + + Each row of the (N, 2+k) array is one event slot: + col 0 : start, normalised to [0, 1] over T + col 1 : length, normalised to [0, 1] over T + col 2.. : k binary class columns (all-zero row = empty slot) Solver contract --------------- @@ -159,26 +164,34 @@ def _eval_classification(self, model): result[name] = ALL_METRICS[name](y_true, y_pred) return result - # --- event detection ----------------------------------------------- + # --- anomaly detection --------------------------------------------- - def _eval_event_detection(self, model): - # model.predict returns (N, 2+K) float array per series - preds = [np.asarray(model.predict(x)) for x in self.X_test] + def _eval_anomaly_detection(self, model): + # model.predict returns (T_j,) float scores per series + scores = [np.asarray(model.predict(x)) for x in self.X_test] result = {} for name in self.metrics: - result[name] = ALL_METRICS[name](self.y_test, preds) + result[name] = ALL_METRICS[name](self.y_test, scores) return result - # --- anomaly detection --------------------------------------------- + # --- event detection ----------------------------------------------- - def _eval_anomaly_detection(self, model): - # model.predict returns (T_j,) float scores per series - scores = [np.asarray(model.predict(x)) for x in self.X_test] + def _eval_event_detection(self, model): + """Evaluate event detection. + + model.predict(x) must return (N, 2+k) float array per series: + col 0 : predicted start in [0, 1] + col 1 : predicted length in [0, 1] + col 2.. : binary class probabilities / scores in [0, 1] + + y_test is a List of (N, 2+k) event arrays (variable or padded). + """ + preds = [np.asarray(model.predict(x)) for x in self.X_test] result = {} for name in self.metrics: - result[name] = ALL_METRICS[name](self.y_test, scores) + result[name] = ALL_METRICS[name](self.y_test, preds) return result # ------------------------------------------------------------------ diff --git a/solvers/chronos.py b/solvers/chronos.py index fe5c621..b1d74a1 100644 --- a/solvers/chronos.py +++ b/solvers/chronos.py @@ -102,7 +102,7 @@ class _ChronosEmbedEncoder(UnpooledEncoder): series in the batch. """ - def __init__(self, pipeline: ChronosPipeline): + def __init__(self, pipeline): self.pipeline = pipeline def encode(self, X) -> np.ndarray: @@ -121,7 +121,7 @@ class _ChronosHookEncoder(UnpooledEncoder): indices are allowed (``-1`` = last block). """ - def __init__(self, pipeline: ChronosPipeline, layer: int): + def __init__(self, pipeline, layer: int): self.pipeline = pipeline n_blocks = len(pipeline.model.model.encoder.block) if not -n_blocks <= layer < n_blocks: @@ -159,9 +159,7 @@ def _hook(_module, _inputs, output): return captured["h"].transpose(0, 1).float().cpu().numpy() -def ChronosEncoder( - pipeline: ChronosPipeline, layer: int | None = None -) -> UnpooledEncoder: +def ChronosEncoder(pipeline, layer: int | None = None) -> UnpooledEncoder: """Build a Chronos feature extractor. Parameters diff --git a/solvers/chronos_event.py b/solvers/chronos_event.py new file mode 100644 index 0000000..e72fe77 --- /dev/null +++ b/solvers/chronos_event.py @@ -0,0 +1,326 @@ +"""Chronos-based event-detection solver for the TSFM benchmark. + +Supports: + - event_detection : frozen Chronos T5 encoder + trainable EventHead + +Overview +-------- +Event detection is a structured-prediction task. Each input series x (T, C) +is mapped to a fixed-size set of N=10 span predictions: + + output : (N=10, 2+k) + col 0 : event start, normalised to [0, 1] over T=512 + col 1 : event length, normalised to [0, 1] over T=512 + col 2.. : k binary class probability columns + +y_train / y_test format (padded) +--------------------------------- +Each element is a float32 array of shape (N=10, 2+k). Empty/no-event slots +are represented as all-zero rows. Real event slots satisfy + row[2:].sum() >= 1 +The solver assumes this padded format and reads n_classes from meta +(key "n_classes") and T from meta (key "T", default 512). + +Meta keys expected from the dataset +------------------------------------- + n_classes : int number of binary class columns k + T : int series length (default 512) + +These must be returned by the dataset's get_data() via the ``extra`` dict +that gets forwarded to objective.set_data() and then to get_objective(). + +Architecture (frozen encoder, trained head) +-------------------------------------------- +Chronos is purely univariate. To handle C channels we embed each channel +independently via pipeline.embed() and **mean-pool the (T_tok, D) tensors** +across channels. This is parameter-free and works for any C at inference time. +Alternatives (concat, attention aggregation) are noted in event_detection.py. + +Because Chronos is frozen, all training embeddings are pre-computed once in +set_objective() (untimed) and cached in CPU RAM. The timed run() block only +trains the small EventHead — very fast on H100. + +Hyperparameters (H100 defaults) +--------------------------------- + model_size : "small" (D=512, 46 M params) + model_path : "" empty = download from HuggingFace Hub + set to a local directory path to load offline + e.g. "/path/to/models/chronos_t5_small" + batch_size : 32 + num_epochs : 100 + lr : 3e-4 + weight_decay : 1e-4 + warmup_epochs : 5 (linear warmup, cosine decay thereafter) + num_queries : 10 (= N, fixed) + num_dec_layers : 2 + lambda_cls : 1.0 (weight of class loss vs. position loss) + +Benchopt timing contract +-------------------------- + set_objective() — model loading + embedding pre-computation (UNTIMED) + run() — EventHead training (TIMED) + get_result() — returns {"model": ChronosEventAdapter} +""" + +import math + +import numpy as np +import torch +from benchopt import BaseSolver + +from benchmark_utils.adapters.event_detection import ChronosEventAdapter, EventHead + + +SUPPORTED_TASKS = {"event_detection"} + +# Chronos encoder output dimensions by model size +_CHRONOS_D = { + "tiny": 64, + "mini": 128, + "small": 512, + "base": 768, + "large": 1024, +} + + +def _get_linear_cosine_scheduler(optimizer, warmup_epochs, total_epochs): + """Linear warmup + cosine annealing LR scheduler (epoch-level).""" + def lr_lambda(epoch): + if epoch < warmup_epochs: + return float(epoch + 1) / float(max(1, warmup_epochs)) + progress = float(epoch - warmup_epochs) / float( + max(1, total_epochs - warmup_epochs) + ) + return 0.5 * (1.0 + math.cos(math.pi * progress)) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + +class Solver(BaseSolver): + """Chronos encoder (frozen) + EventHead event-detection solver. + + See module docstring for full details on the architecture, data format, + and hyperparameter choices. + """ + + name = "Chronos-EventDetection" + + requirements = [ + "pip::chronos-forecasting>=1.4", + "pip::torch", + ] + + sampling_strategy = "run_once" + + parameters = { + "model_size": ["small"], + "model_path": [""], # empty = load from HuggingFace Hub + "batch_size": [32], + "num_epochs": [0, 100], + "lr": [3e-4], + "weight_decay": [1e-4], + "warmup_epochs": [5], + "num_dec_layers": [2], + "lambda_cls": [1.0], + } + + def skip(self, task, **kwargs): + if task not in SUPPORTED_TASKS: + return True, f"Chronos-EventDetection does not support task={task!r}" + return False, None + + # ------------------------------------------------------------------ + # set_objective — UNTIMED + # ------------------------------------------------------------------ + + def set_objective(self, X_train, y_train, task, **meta): + """Load Chronos, freeze it, and pre-compute training embeddings. + + Parameters + ---------- + X_train : List[np.ndarray (T, C)] + y_train : List[np.ndarray (N=10, 2+k)] padded event targets + task : str must be "event_detection" + **meta : must include "n_classes" (int) and optionally "T" (int) + """ + import torch as _torch + from chronos import ChronosPipeline + + self.task = task + self.X_train = X_train + self.y_train = y_train + self.meta = meta + + # --- Infer task dimensions --- + self.n_classes = int(meta["n_classes"]) + self.T = int(meta.get("T", 512)) + self.k = self.n_classes # alias + + # --- Device --- + self.device = "cuda" if _torch.cuda.is_available() else "cpu" + + # --- Resolve model identifier --- + # model_path="" (default) → load from HuggingFace Hub + # model_path="/some/local/dir" → load from that local directory + model_id = ( + self.model_path.strip() + if self.model_path.strip() + else f"amazon/chronos-t5-{self.model_size}" + ) + + # --- Load Chronos pipeline (once; cached across dataset configs) --- + should_reload = ( + not hasattr(self, "_pipeline") + or not hasattr(self, "_loaded_model_id") + or self._loaded_model_id != model_id + ) + if should_reload: + self._pipeline = ChronosPipeline.from_pretrained( + model_id, + device_map="auto", + dtype=_torch.bfloat16, + ) + self._loaded_model_id = model_id + print(f"Loaded Chronos checkpoint: {model_id}") + + # --- Freeze encoder --- + for param in self._pipeline.model.parameters(): + param.requires_grad = False + self._pipeline.model.eval() + + # --- Encoder output dimension --- + self.d_model = _CHRONOS_D.get(self.model_size, 512) + + # --- Pre-compute training embeddings (frozen encoder → one-time) --- + print( + f"Pre-computing embeddings for {len(X_train)} training series " + f"(C={X_train[0].shape[1]}, T={self.T}) ..." + ) + self._Z_train = [] # List of (T_tok, D) CPU tensors + with _torch.no_grad(): + for x in X_train: + emb = self._embed_series(x) # (T_tok, D) float32 on CPU + self._Z_train.append(emb) + + print("Embedding pre-computation complete.") + + # ------------------------------------------------------------------ + # run — TIMED + # ------------------------------------------------------------------ + + def run(self, _): + """Train the EventHead on top of cached Chronos embeddings.""" + import torch as _torch + + head = EventHead( + d_model=self.d_model, + n_classes=self.n_classes, + num_queries=10, + num_decoder_layers=self.num_dec_layers, + nhead=8, + ).to(self.device) + head.train() + + optimizer = _torch.optim.AdamW( + head.parameters(), + lr=self.lr, + weight_decay=self.weight_decay, + ) + scheduler = _get_linear_cosine_scheduler( + optimizer, self.warmup_epochs, self.num_epochs + ) + + N_train = len(self._Z_train) + use_amp = self.device == "cuda" + scaler = _torch.amp.GradScaler("cuda", enabled=use_amp) + + for epoch in range(self.num_epochs): + indices = np.random.permutation(N_train) + epoch_loss = 0.0 + num_batches = 0 + + for batch_start in range(0, N_train, self.batch_size): + batch_idx = indices[batch_start: batch_start + self.batch_size] + + # --- Collate: pad to same T_tok (all same T → same T_tok) --- + embs = [self._Z_train[i] for i in batch_idx] + max_ttok = max(e.shape[0] for e in embs) + D = embs[0].shape[1] + B = len(embs) + + memory = _torch.zeros(B, max_ttok, D, dtype=_torch.float32) + for bi, e in enumerate(embs): + memory[bi, : e.shape[0]] = e + memory = memory.to(self.device) + + # --- Targets --- + y_batch = _torch.tensor( + np.stack([self.y_train[i] for i in batch_idx]), + dtype=_torch.float32, + device=self.device, + ) # (B, N, 2+k) + + # --- Forward + loss --- + optimizer.zero_grad() + with _torch.amp.autocast("cuda", dtype=_torch.bfloat16, + enabled=use_amp): + pos_logits, cls_logits = head(memory) + loss = head.compute_loss( + pos_logits, cls_logits, y_batch, + lambda_cls=self.lambda_cls, + ) + + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + _torch.nn.utils.clip_grad_norm_(head.parameters(), max_norm=1.0) + scaler.step(optimizer) + scaler.update() + + epoch_loss += loss.item() + num_batches += 1 + + scheduler.step() + + if (epoch + 1) % 10 == 0 or epoch == 0: + avg = epoch_loss / max(num_batches, 1) + lr_now = scheduler.get_last_lr()[0] + print( + f" Epoch {epoch + 1:3d}/{self.num_epochs} | " + f"loss={avg:.4f} | lr={lr_now:.2e}" + ) + + head.eval() + self._adapter = ChronosEventAdapter( + pipeline=self._pipeline, + head=head, + device=self.device, + n_classes=self.n_classes, + T=self.T, + ) + + # ------------------------------------------------------------------ + # get_result + # ------------------------------------------------------------------ + + def get_result(self): + return {"model": self._adapter} + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _embed_series(self, x: np.ndarray) -> torch.Tensor: + """Embed a (T, C) array via frozen Chronos, mean-pool across C. + + Returns + ------- + Tensor (T_tok, D) float32 on CPU + """ + C = x.shape[1] + channel_embs = [] + for c in range(C): + ctx = torch.tensor(x[:, c], dtype=torch.float32) + emb, _ = self._pipeline.embed(ctx.unsqueeze(0)) # (1, T_tok, D) + channel_embs.append(emb.squeeze(0).float().cpu()) + stacked = torch.stack(channel_embs, dim=0) # (C, T_tok, D) + return stacked.mean(dim=0) # (T_tok, D) From a10a35995b7ee16951845f06570c673229a17b66 Mon Sep 17 00:00:00 2001 From: ABarneche Date: Fri, 29 May 2026 11:36:46 +0200 Subject: [PATCH 04/15] adding unit tests for map_iou and f1 score metric for event detection --- benchmark_utils/metrics.py | 95 +++++++++ tests/benchmark_utils/test_metrics_map_iou.py | 185 ++++++++++++++++++ 2 files changed, 280 insertions(+) create mode 100644 tests/benchmark_utils/test_metrics_map_iou.py diff --git a/benchmark_utils/metrics.py b/benchmark_utils/metrics.py index ae62e3b..8a60a41 100644 --- a/benchmark_utils/metrics.py +++ b/benchmark_utils/metrics.py @@ -254,6 +254,100 @@ def map_iou(y_true, y_pred, iou_threshold=0.5): return float(np.mean(valid)) if valid else float("nan") +def _collect_tp_pairs(gt_by_series, preds, iou_threshold): + """Greedy IoU matching. Returns list of ([s,w]_pred, [s,w]_gt) for each TP.""" + matched = {i: [False] * len(boxes) for i, boxes in gt_by_series.items()} + pairs = [] + for i, s, w, _ in preds: + best_iou, best_gi = 0.0, -1 + for gi, (gs, gw) in enumerate(gt_by_series.get(i, [])): + iou = _iou_1d(s, w, gs, gw) + if iou > best_iou: + best_iou, best_gi = iou, gi + if best_iou >= iou_threshold and best_gi >= 0 and not matched[i][best_gi]: + gs, gw = gt_by_series[i][best_gi] + pairs.append(([s, w], [gs, gw])) + matched[i][best_gi] = True + return pairs + + +def _match_events(gt_by_series, preds, iou_threshold): + """Greedy IoU matching (highest-score first). Returns (tp, fp, fn).""" + matched = {i: [False] * len(boxes) for i, boxes in gt_by_series.items()} + n_gt = sum(len(b) for b in gt_by_series.values()) + tp = fp = 0 + for i, s, w, _ in preds: + best_iou, best_gi = 0.0, -1 + for gi, (gs, gw) in enumerate(gt_by_series.get(i, [])): + iou = _iou_1d(s, w, gs, gw) + if iou > best_iou: + best_iou, best_gi = iou, gi + if best_iou >= iou_threshold and best_gi >= 0 and not matched[i][best_gi]: + tp += 1 + matched[i][best_gi] = True + else: + fp += 1 + return tp, fp, n_gt - tp + + +def f1_det(y_true, y_pred, iou_threshold=0.5, score_threshold=None): + """Macro F1 for event detection. + + Parameters + ---------- + y_true, y_pred : same format as map_iou + iou_threshold : minimum IoU to count a match (default 0.5) + score_threshold: fixed class score threshold; if None, sweep to maximise F1 + per class (oracle — for benchmarking purposes only) + """ + if not y_true: + return float("nan") + + n_classes = y_true[0].shape[1] - 2 + f1s = [] + + for k in range(n_classes): + gt_by_series = {} + n_gt = 0 + for i, gt in enumerate(y_true): + boxes = [(row[0], row[1]) for row in gt + if len(gt) > 0 and np.argmax(row[2:]) == k] + gt_by_series[i] = boxes + n_gt += len(boxes) + + if n_gt == 0: + f1s.append(float("nan")) + continue + + all_preds = [] + for i, pred in enumerate(y_pred): + for row in pred: + all_preds.append((i, row[0], row[1], float(row[2 + k]))) + all_preds.sort(key=lambda x: -x[3]) + + if score_threshold is None: + scores = np.array([p[3] for p in all_preds]) + thresholds = ( + np.percentile(scores, np.arange(0, 100, 1)) + if len(scores) else [0.5] + ) + best_f1 = 0.0 + for thr in thresholds: + preds = [p for p in all_preds if p[3] >= thr] + tp, fp, fn = _match_events(gt_by_series, preds, iou_threshold) + denom = 2 * tp + fp + fn + best_f1 = max(best_f1, (2 * tp / denom) if denom > 0 else 0.0) + f1s.append(best_f1) + else: + preds = [p for p in all_preds if p[3] >= score_threshold] + tp, fp, fn = _match_events(gt_by_series, preds, iou_threshold) + denom = 2 * tp + fp + fn + f1s.append((2 * tp / denom) if denom > 0 else 0.0) + + valid = [f for f in f1s if not np.isnan(f)] + return float(np.mean(valid)) if valid else float("nan") + + # --------------------------------------------------------------------------- # Registry: maps metric name → function # --------------------------------------------------------------------------- @@ -280,6 +374,7 @@ def map_iou(y_true, y_pred, iou_threshold=0.5): EVENT_METRICS = { "map_iou": map_iou, + "f1_det": f1_det, } ALL_METRICS = {**FORECASTING_METRICS, **CLASSIFICATION_METRICS, **AD_METRICS, diff --git a/tests/benchmark_utils/test_metrics_map_iou.py b/tests/benchmark_utils/test_metrics_map_iou.py new file mode 100644 index 0000000..4ad55a7 --- /dev/null +++ b/tests/benchmark_utils/test_metrics_map_iou.py @@ -0,0 +1,185 @@ +"""Unit tests for map_iou, f1_det and their helpers.""" + +import math + +import numpy as np +import pytest + +from benchmark_utils.metrics import _ap_from_tp_fp, _iou_1d, f1_det, map_iou + + +def _evt(start, width, *class_cols): + """Single-row event array: [start, width, *class_cols].""" + return np.array([[start, width, *class_cols]], dtype=float) + + +def _no_evts(n_classes): + return np.zeros((0, 2 + n_classes)) + + +# --------------------------------------------------------------------------- +# _iou_1d +# --------------------------------------------------------------------------- + +def test_iou_identical(): + assert _iou_1d(0.0, 0.5, 0.0, 0.5) == pytest.approx(1.0) + +def test_iou_no_overlap(): + assert _iou_1d(0.0, 0.3, 0.4, 0.3) == pytest.approx(0.0) + +def test_iou_adjacent(): + # Segments touch at a single point — inter = 0 + assert _iou_1d(0.0, 0.5, 0.5, 0.5) == pytest.approx(0.0) + +def test_iou_partial(): + # [0, 1) and [0.5, 1.5) → inter=0.5, union=1.5 + assert _iou_1d(0.0, 1.0, 0.5, 1.0) == pytest.approx(0.5 / 1.5) + +def test_iou_contained(): + # [0, 1) contains [0.2, 0.6) → inter=0.4, union=1.0 + assert _iou_1d(0.0, 1.0, 0.2, 0.4) == pytest.approx(0.4 / 1.0) + +def test_iou_zero_width_returns_zero(): + assert _iou_1d(0.5, 0.0, 0.5, 0.0) == pytest.approx(0.0) + + +# --------------------------------------------------------------------------- +# _ap_from_tp_fp +# --------------------------------------------------------------------------- + +def test_ap_no_gt_returns_nan(): + assert math.isnan(_ap_from_tp_fp(np.array([1.0]), np.array([0.0]), 0)) + +def test_ap_all_tp(): + ap = _ap_from_tp_fp(np.array([1.0, 1.0]), np.array([0.0, 0.0]), n_gt=2) + assert ap == pytest.approx(1.0) + +def test_ap_all_fp(): + ap = _ap_from_tp_fp(np.array([0.0, 0.0]), np.array([1.0, 1.0]), n_gt=2) + assert ap == pytest.approx(0.0) + +def test_ap_tp_then_fp(): + # Rank 1: TP (precision=1, recall=0.5), Rank 2: FP (recall still 0.5) + # AP = Δrecall × precision = 0.5 × 1.0 = 0.5 + ap = _ap_from_tp_fp(np.array([1.0, 0.0]), np.array([0.0, 1.0]), n_gt=2) + assert ap == pytest.approx(0.5) + + +# --------------------------------------------------------------------------- +# map_iou +# --------------------------------------------------------------------------- + +def test_map_empty_y_true_returns_nan(): + assert math.isnan(map_iou([], [])) + +def test_map_no_predictions_returns_zero(): + assert map_iou([_evt(0.0, 0.5, 1)], [_no_evts(1)]) == pytest.approx(0.0) + +def test_map_no_gt_events_returns_nan(): + # Series has no GT events — AP is undefined + assert math.isnan(map_iou([_no_evts(1)], [_evt(0.0, 0.5, 0.9)])) + +def test_map_perfect_match(): + assert map_iou([_evt(0.1, 0.4, 1)], [_evt(0.1, 0.4, 1.0)]) == pytest.approx(1.0) + +def test_map_no_overlap(): + assert map_iou([_evt(0.0, 0.3, 1)], [_evt(0.7, 0.3, 1.0)]) == pytest.approx(0.0) + +def test_map_overlap_below_threshold_is_miss(): + # IoU ≈ 0.333 < 0.5 + assert map_iou([_evt(0.0, 1.0, 1)], [_evt(0.5, 1.0, 1.0)], iou_threshold=0.5) == pytest.approx(0.0) + +def test_map_overlap_above_custom_threshold_is_hit(): + # Same IoU ≈ 0.333 > 0.3 + assert map_iou([_evt(0.0, 1.0, 1)], [_evt(0.5, 1.0, 1.0)], iou_threshold=0.3) == pytest.approx(1.0) + +def test_map_duplicate_pred_only_first_matched(): + # Two identical predictions for one GT: first TP, second FP → AP=1.0 + preds = [np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 0.9]])] + assert map_iou([_evt(0.0, 0.5, 1)], preds) == pytest.approx(1.0) + +def test_map_predictions_ranked_by_score(): + # High-score pred misses (FP at rank 1), low-score pred hits (TP at rank 2) + # AP = Δrecall × precision at rank 2 = 1.0 × 0.5 = 0.5 + preds = [np.array([[0.8, 0.1, 0.9], [0.0, 0.5, 0.5]])] + assert map_iou([_evt(0.0, 0.5, 1)], preds) == pytest.approx(0.5) + +def test_map_two_classes_both_perfect(): + gt = np.array([[0.0, 0.3, 1, 0], [0.5, 0.3, 0, 1]]) + pred = np.array([[0.0, 0.3, 1.0, 0.0], [0.5, 0.3, 0.0, 1.0]]) + assert map_iou([gt], [pred]) == pytest.approx(1.0) + +def test_map_two_classes_one_missed(): + gt = np.array([[0.0, 0.3, 1, 0], [0.5, 0.3, 0, 1]]) + pred = np.array([ + [0.0, 0.3, 1.0, 0.0], # class 0: perfect hit + [0.0, 0.1, 0.0, 1.0], # class 1: no overlap with GT at [0.5, 0.3] + ]) + assert map_iou([gt], [pred]) == pytest.approx(0.5) + +def test_map_multi_series_both_matched(): + y_true = [_evt(0.0, 0.5, 1), _evt(0.2, 0.4, 1)] + y_pred = [_evt(0.0, 0.5, 1.0), _evt(0.2, 0.4, 1.0)] + assert map_iou(y_true, y_pred) == pytest.approx(1.0) + +def test_map_prediction_in_wrong_series_is_fp(): + # GT in series 0, pred only in series 1 — should not match + y_true = [_evt(0.0, 0.5, 1), _no_evts(1)] + y_pred = [_no_evts(1), _evt(0.0, 0.5, 1.0)] + assert map_iou(y_true, y_pred) == pytest.approx(0.0) + + +# --------------------------------------------------------------------------- +# f1_det +# --------------------------------------------------------------------------- + +def test_f1det_empty_y_true_returns_nan(): + assert math.isnan(f1_det([], [])) + +def test_f1det_no_gt_events_returns_nan(): + assert math.isnan(f1_det([_no_evts(1)], [_evt(0.0, 0.5, 0.9)])) + +def test_f1det_perfect_match_fixed_threshold(): + assert f1_det([_evt(0.0, 0.5, 1)], [_evt(0.0, 0.5, 0.9)], score_threshold=0.5) == pytest.approx(1.0) + +def test_f1det_no_overlap_fixed_threshold(): + # TP=0, FP=1, FN=1 → F1=0 + assert f1_det([_evt(0.0, 0.3, 1)], [_evt(0.7, 0.3, 0.9)], score_threshold=0.5) == pytest.approx(0.0) + +def test_f1det_prediction_below_threshold_becomes_fn(): + # Pred score 0.3 < threshold 0.5 → filtered out → TP=0, FP=0, FN=1 → F1=0 + assert f1_det([_evt(0.0, 0.5, 1)], [_evt(0.0, 0.5, 0.3)], score_threshold=0.5) == pytest.approx(0.0) + +def test_f1det_oracle_finds_best_threshold(): + # Low-score pred (0.2) would be filtered by fixed threshold=0.5 but oracle should find it + assert f1_det([_evt(0.0, 0.5, 1)], [_evt(0.0, 0.5, 0.2)], score_threshold=None) == pytest.approx(1.0) + +def test_f1det_duplicate_pred_second_is_fp(): + # TP=1, FP=1, FN=0 → F1 = 2/(2+1+0) = 2/3 + preds = [np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 0.9]])] + assert f1_det([_evt(0.0, 0.5, 1)], preds, score_threshold=0.5) == pytest.approx(2 / 3) + +def test_f1det_two_classes_both_perfect(): + gt = np.array([[0.0, 0.3, 1, 0], [0.5, 0.3, 0, 1]]) + pred = np.array([[0.0, 0.3, 1.0, 0.0], [0.5, 0.3, 0.0, 1.0]]) + assert f1_det([gt], [pred], score_threshold=0.5) == pytest.approx(1.0) + +def test_f1det_two_classes_one_missed(): + gt = np.array([[0.0, 0.3, 1, 0], [0.5, 0.3, 0, 1]]) + pred = np.array([ + [0.0, 0.3, 1.0, 0.0], # class 0: perfect hit + [0.0, 0.1, 0.0, 0.9], # class 1: no overlap with GT at [0.5, 0.3] + ]) + # class 0: F1=1.0, class 1: TP=0 FP=1 FN=1 → F1=0 → mean=0.5 + assert f1_det([gt], [pred], score_threshold=0.5) == pytest.approx(0.5) + +def test_f1det_multi_series_both_matched(): + y_true = [_evt(0.0, 0.5, 1), _evt(0.2, 0.4, 1)] + y_pred = [_evt(0.0, 0.5, 0.9), _evt(0.2, 0.4, 0.8)] + assert f1_det(y_true, y_pred, score_threshold=0.5) == pytest.approx(1.0) + +def test_f1det_prediction_in_wrong_series_is_fp(): + y_true = [_evt(0.0, 0.5, 1), _no_evts(1)] + y_pred = [_no_evts(1), _evt(0.0, 0.5, 0.9)] + # TP=0, FP=1, FN=1 → F1=0 + assert f1_det(y_true, y_pred, score_threshold=0.5) == pytest.approx(0.0) From 4bd940ab80ca515f5d4c4a54873c48cd34b57ce1 Mon Sep 17 00:00:00 2001 From: Chenwei Wan Date: Fri, 29 May 2026 12:07:45 +0200 Subject: [PATCH 05/15] ENH rename event_class_f1 to event_iou_f1 with matching/mode options Add matching_strategy ("greedy"/"hungarian") and mode ("micro"/"macro") parameters. Greedy now sorts predictions by descending confidence before claiming the highest-IoU free ground-truth span. Penalise unmatched and duplicate predictions as false positives over their active class columns. Co-Authored-By: Claude Opus 4.8 (1M context) --- benchmark_utils/metrics.py | 205 ++++++++++++++++++++++++++++++------- 1 file changed, 166 insertions(+), 39 deletions(-) diff --git a/benchmark_utils/metrics.py b/benchmark_utils/metrics.py index 2eb010b..9e5df7e 100644 --- a/benchmark_utils/metrics.py +++ b/benchmark_utils/metrics.py @@ -329,30 +329,167 @@ def event_span_iou(y_true, y_pred, iou_threshold=0.5): return float(np.mean(f1_scores)) -def event_class_f1(y_true, y_pred, iou_threshold=0.5): - """Micro-F1 over binary class columns on IoU-matched event slots. +def _match_spans(gt_spans, pr_spans, iou_threshold, matching_strategy): + """Match predicted spans to ground-truth spans by 1-D IoU. - For each predicted span that is IoU-matched to a ground-truth span, we - threshold class probabilities at 0.5 and compute micro-F1 over the k - binary class columns. Unmatched ground-truth spans count as false negatives - for all their active classes. + Each ground-truth span and each prediction is matched at most once; pairs + whose IoU is below ``iou_threshold`` are discarded so that unmatched and + duplicate predictions are left out of the matching. Parameters ---------- - y_true : List[np.ndarray (N=10, 2+k)] - y_pred : List[np.ndarray (N=10, 2+k)] + gt_spans : np.ndarray (G, 2+k) + Ground-truth spans (cols 0-1 = start/width, cols 2: = class columns). + pr_spans : np.ndarray (P, 2+k) + Predicted spans, same layout; class columns hold scores in [0, 1]. + iou_threshold : float + Minimum IoU to accept a match. + matching_strategy : {"greedy", "hungarian"} + ``"greedy"`` sorts predictions by descending confidence and assigns + each one to its highest-IoU still-free ground-truth span. + ``"hungarian"`` solves the maximum-IoU bipartite assignment globally. + + Returns + ------- + List[Tuple[int, int]] + ``(gi, pi)`` index pairs of matched ground-truth and prediction spans. + """ + G = gt_spans.shape[0] + P = pr_spans.shape[0] + if G == 0 or P == 0: + return [] + + # IoU matrix: rows = predictions, cols = ground-truth spans. + iou = np.zeros((P, G)) + for pi in range(P): + for gi in range(G): + iou[pi, gi] = _span_iou( + pr_spans[pi, 0], pr_spans[pi, 1], + gt_spans[gi, 0], gt_spans[gi, 1], + ) + + if matching_strategy == "greedy": + # Sort predictions by descending confidence (max class score) so that + # the most confident prediction claims its best ground-truth first. + order = np.argsort(-pr_spans[:, 2:].max(axis=1)) + matched_gt = set() + pairs = [] + for pi in order: + best_iou = 0.0 + best_gi = -1 + for gi in range(G): + if gi in matched_gt: + continue + if iou[pi, gi] > best_iou: + best_iou = iou[pi, gi] + best_gi = gi + # Accept the match only if the best free ground-truth clears the + # threshold; the chosen ground-truth is then marked as occupied. + if best_gi >= 0 and best_iou >= iou_threshold: + matched_gt.add(best_gi) + pairs.append((best_gi, pi)) + return pairs + + if matching_strategy == "hungarian": + # Maximum-weight bipartite matching on the IoU matrix (negated for the + # minimisation solver), keeping only pairs above the threshold. + from scipy.optimize import linear_sum_assignment + + pr_idx, gt_idx = linear_sum_assignment(-iou) + return [(int(gi), int(pi)) for pi, gi in zip(pr_idx, gt_idx) + if iou[pi, gi] >= iou_threshold] + + raise ValueError( + f"Unknown matching_strategy={matching_strategy!r}; " + "expected 'greedy' or 'hungarian'." + ) + + +def _f1_from_class_counts(tp, fp, fn, mode): + """Combine per-class TP/FP/FN counts into a single F1 score. + + Parameters + ---------- + tp, fp, fn : np.ndarray (k,) + Per-class true-positive, false-positive and false-negative counts. + mode : {"micro", "macro"} + ``"micro"`` pools counts across classes before computing one F1. + ``"macro"`` computes per-class F1 then averages them equally. + """ + if mode == "micro": + tp_s, fp_s, fn_s = tp.sum(), fp.sum(), fn.sum() + precision = tp_s / (tp_s + fp_s) if (tp_s + fp_s) > 0 else 0.0 + recall = tp_s / (tp_s + fn_s) if (tp_s + fn_s) > 0 else 0.0 + if precision + recall > 0: + return float(2 * precision * recall / (precision + recall)) + return 0.0 + + if mode == "macro": + f1s = [] + for k in range(len(tp)): + denom_p = tp[k] + fp[k] + denom_r = tp[k] + fn[k] + precision = tp[k] / denom_p if denom_p > 0 else 0.0 + recall = tp[k] / denom_r if denom_r > 0 else 0.0 + if precision + recall > 0: + f1s.append(2 * precision * recall / (precision + recall)) + else: + f1s.append(0.0) + return float(np.mean(f1s)) if f1s else 0.0 + + raise ValueError( + f"Unknown mode={mode!r}; expected 'micro' or 'macro'." + ) + + +def event_iou_f1(y_true, y_pred, iou_threshold=0.5, + matching_strategy="greedy", mode="micro"): + """F1 over binary class columns on IoU-matched event spans. + + Predicted spans are matched to ground-truth spans by 1-D IoU using the + requested strategy. For each matched pair, class probabilities are + thresholded at 0.5 and contribute true positives / false positives / false + negatives over the k binary class columns. Unmatched ground-truth spans + count as false negatives for all their active classes, while unmatched and + duplicate predictions count as false positives for all their active + classes, so both kinds of error are penalised. + + Parameters + ---------- + y_true : List[np.ndarray (N, 2+k)] + Ground-truth padded event targets. All-zero class columns = empty slot. + y_pred : List[np.ndarray (N, 2+k)] + Predicted outputs. Cols 0-1 = start/width, cols 2: = class scores. iou_threshold : float + Minimum IoU to accept a span match (default 0.5). + matching_strategy : {"greedy", "hungarian"} + Span matching strategy (default "greedy"). See :func:`_match_spans`. + mode : {"micro", "macro"} + Class-averaging mode for the final F1 (default "micro"). Returns ------- - float — micro-F1 over class columns + float — class F1 score aggregated over all series. """ - tp_total = fp_total = fn_total = 0 + # Infer the number of class columns from the first 2-D sample available. + n_classes = None + for arr in list(y_true) + list(y_pred): + a = np.asarray(arr) + if a.ndim == 2 and a.shape[1] > 2: + n_classes = a.shape[1] - 2 + break + if n_classes is None: + return float("nan") + + tp = np.zeros(n_classes) + fp = np.zeros(n_classes) + fn = np.zeros(n_classes) for gt, pr in zip(y_true, y_pred): gt = np.asarray(gt) pr = np.asarray(pr) + # Keep only occupied ground-truth slots and confident predictions. gt_mask = gt[:, 2:].sum(axis=1) > 0 pr_mask = pr[:, 2:].max(axis=1) > 0.5 @@ -362,41 +499,31 @@ def event_class_f1(y_true, y_pred, iou_threshold=0.5): G = gt_spans.shape[0] P = pr_spans.shape[0] - matched_gt = {} - matched_pr = set() - for gi in range(G): - best_iou = 0.0 - best_pi = -1 - for pi in range(P): - if pi in matched_pr: - continue - iou = _span_iou( - pr_spans[pi, 0], pr_spans[pi, 1], - gt_spans[gi, 0], gt_spans[gi, 1], - ) - if iou > best_iou: - best_iou = iou - best_pi = pi - if best_iou >= iou_threshold and best_pi >= 0: - matched_gt[gi] = best_pi - matched_pr.add(best_pi) + pairs = _match_spans(gt_spans, pr_spans, iou_threshold, + matching_strategy) + matched_gt = {gi for gi, _ in pairs} + matched_pr = {pi for _, pi in pairs} - for gi, pi in matched_gt.items(): + # Matched pairs: per-class agreement on the thresholded class columns. + for gi, pi in pairs: gt_cls = (gt_spans[gi, 2:] > 0.5).astype(int) pr_cls = (pr_spans[pi, 2:] > 0.5).astype(int) - tp_total += int((gt_cls & pr_cls).sum()) - fp_total += int(((1 - gt_cls) & pr_cls).sum()) - fn_total += int((gt_cls & (1 - pr_cls)).sum()) + tp += gt_cls & pr_cls + fp += (1 - gt_cls) & pr_cls + fn += gt_cls & (1 - pr_cls) + # Unmatched ground-truth spans: every active class is a false negative. for gi in range(G): if gi not in matched_gt: - fn_total += int((gt_spans[gi, 2:] > 0.5).sum()) + fn += (gt_spans[gi, 2:] > 0.5).astype(int) + + # Unmatched / duplicate predictions: every active class is a false + # positive (penalty for over-prediction). + for pi in range(P): + if pi not in matched_pr: + fp += (pr_spans[pi, 2:] > 0.5).astype(int) - precision = tp_total / (tp_total + fp_total) if (tp_total + fp_total) > 0 else 0.0 - recall = tp_total / (tp_total + fn_total) if (tp_total + fn_total) > 0 else 0.0 - if precision + recall > 0: - return float(2 * precision * recall / (precision + recall)) - return 0.0 + return _f1_from_class_counts(tp, fp, fn, mode) # --------------------------------------------------------------------------- @@ -426,7 +553,7 @@ def event_class_f1(y_true, y_pred, iou_threshold=0.5): EVENT_METRICS = { "map_iou": map_iou, "event_span_iou": event_span_iou, - "event_class_f1": event_class_f1, + "event_iou_f1": event_iou_f1, } ALL_METRICS = { From 705e6f76db57f32fa4210120dc0fc109b37ecc35 Mon Sep 17 00:00:00 2001 From: ABarneche Date: Fri, 29 May 2026 12:14:57 +0200 Subject: [PATCH 06/15] adding truncation in dataset --- datasets/mitdb.py | 26 ++++++++++++++++++++++---- objective.py | 5 +++-- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/datasets/mitdb.py b/datasets/mitdb.py index f8a06f5..61e9ccc 100644 --- a/datasets/mitdb.py +++ b/datasets/mitdb.py @@ -12,12 +12,14 @@ class 4 Q — Unknown / pacemaker artefact Data contract output -------------------- X_train : List[np.ndarray (T_i, 2)] training portions (C == 2) -y_train : List[np.ndarray (N_i, 2+K)] one row per beat event: +y_train : List[np.ndarray (N, 2+K)] one row per beat event, zero-padded + to N = max events across all segments: col 0 start (normalised, 0–1) col 1 width (normalised, 0–1) cols 2… one-hot class vector (K) + all-zero row = empty/padding slot X_test : List[np.ndarray (T_j, 2)] test portions -y_test : List[np.ndarray (N_j, 2+K)] same format +y_test : List[np.ndarray (N, 2+K)] same format, same N task : "event_detection" metrics : ["map_iou"] extra : n_classes (int) K above @@ -180,11 +182,27 @@ def get_data(self): self.beat_window, self.n_classes, )) + # Pad all event arrays to a uniform N so solvers can np.stack them. + # N = max events in any single segment across train and test. + all_y = y_train + y_test + max_n = max((y.shape[0] for y in all_y), default=0) + n_cols = 2 + self.n_classes + + def _pad(arrays): + out = [] + for y in arrays: + n = y.shape[0] + if n < max_n: + pad = np.zeros((max_n - n, n_cols), dtype=np.float32) + y = np.concatenate([y, pad], axis=0) + out.append(y) + return out + return dict( X_train=X_train, - y_train=y_train, + y_train=_pad(y_train), X_test=X_test, - y_test=y_test, + y_test=_pad(y_test), task="event_detection", metrics=["map_iou"], n_classes=self.n_classes, diff --git a/objective.py b/objective.py index 0f0474a..d815b48 100644 --- a/objective.py +++ b/objective.py @@ -35,8 +35,9 @@ extra n_classes (int) anomaly_detection y_train None y_test List[(T_j,)] int point-level binary labels -event_detection y_train List[(N, 2+K)] float object-detection boxes - y_test List[(N, 2+K)] float object-detection boxes +event_detection y_train List[(N, 2+K)] float object-detection boxes, + zero-padded to uniform N + y_test List[(N, 2+K)] float same format, same N extra n_classes (int), T (int, default 512) Each row of the (N, 2+k) array is one event slot: From b3dea145f86515f415cacd4a8f8874d2f3d20249 Mon Sep 17 00:00:00 2001 From: Chenwei Wan Date: Fri, 29 May 2026 12:07:45 +0200 Subject: [PATCH 07/15] ENH rename event_class_f1 to event_iou_f1 with matching/mode options Add matching_strategy ("greedy"/"hungarian") and mode ("micro"/"macro") parameters. Greedy now sorts predictions by descending confidence before claiming the highest-IoU free ground-truth span. Penalise unmatched and duplicate predictions as false positives over their active class columns. Co-Authored-By: Claude Opus 4.8 (1M context) --- benchmark_utils/metrics.py | 208 +++++++++++++++++++++++++++++-------- 1 file changed, 166 insertions(+), 42 deletions(-) diff --git a/benchmark_utils/metrics.py b/benchmark_utils/metrics.py index 17e0b38..06fe7e0 100644 --- a/benchmark_utils/metrics.py +++ b/benchmark_utils/metrics.py @@ -423,30 +423,167 @@ def event_span_iou(y_true, y_pred, iou_threshold=0.5): return float(np.mean(f1_scores)) -def event_class_f1(y_true, y_pred, iou_threshold=0.5): - """Micro-F1 over binary class columns on IoU-matched event slots. +def _match_spans(gt_spans, pr_spans, iou_threshold, matching_strategy): + """Match predicted spans to ground-truth spans by 1-D IoU. - For each predicted span that is IoU-matched to a ground-truth span, we - threshold class probabilities at 0.5 and compute micro-F1 over the k - binary class columns. Unmatched ground-truth spans count as false negatives - for all their active classes. + Each ground-truth span and each prediction is matched at most once; pairs + whose IoU is below ``iou_threshold`` are discarded so that unmatched and + duplicate predictions are left out of the matching. Parameters ---------- - y_true : List[np.ndarray (N=10, 2+k)] - y_pred : List[np.ndarray (N=10, 2+k)] + gt_spans : np.ndarray (G, 2+k) + Ground-truth spans (cols 0-1 = start/width, cols 2: = class columns). + pr_spans : np.ndarray (P, 2+k) + Predicted spans, same layout; class columns hold scores in [0, 1]. + iou_threshold : float + Minimum IoU to accept a match. + matching_strategy : {"greedy", "hungarian"} + ``"greedy"`` sorts predictions by descending confidence and assigns + each one to its highest-IoU still-free ground-truth span. + ``"hungarian"`` solves the maximum-IoU bipartite assignment globally. + + Returns + ------- + List[Tuple[int, int]] + ``(gi, pi)`` index pairs of matched ground-truth and prediction spans. + """ + G = gt_spans.shape[0] + P = pr_spans.shape[0] + if G == 0 or P == 0: + return [] + + # IoU matrix: rows = predictions, cols = ground-truth spans. + iou = np.zeros((P, G)) + for pi in range(P): + for gi in range(G): + iou[pi, gi] = _span_iou( + pr_spans[pi, 0], pr_spans[pi, 1], + gt_spans[gi, 0], gt_spans[gi, 1], + ) + + if matching_strategy == "greedy": + # Sort predictions by descending confidence (max class score) so that + # the most confident prediction claims its best ground-truth first. + order = np.argsort(-pr_spans[:, 2:].max(axis=1)) + matched_gt = set() + pairs = [] + for pi in order: + best_iou = 0.0 + best_gi = -1 + for gi in range(G): + if gi in matched_gt: + continue + if iou[pi, gi] > best_iou: + best_iou = iou[pi, gi] + best_gi = gi + # Accept the match only if the best free ground-truth clears the + # threshold; the chosen ground-truth is then marked as occupied. + if best_gi >= 0 and best_iou >= iou_threshold: + matched_gt.add(best_gi) + pairs.append((best_gi, pi)) + return pairs + + if matching_strategy == "hungarian": + # Maximum-weight bipartite matching on the IoU matrix (negated for the + # minimisation solver), keeping only pairs above the threshold. + from scipy.optimize import linear_sum_assignment + + pr_idx, gt_idx = linear_sum_assignment(-iou) + return [(int(gi), int(pi)) for pi, gi in zip(pr_idx, gt_idx) + if iou[pi, gi] >= iou_threshold] + + raise ValueError( + f"Unknown matching_strategy={matching_strategy!r}; " + "expected 'greedy' or 'hungarian'." + ) + + +def _f1_from_class_counts(tp, fp, fn, mode): + """Combine per-class TP/FP/FN counts into a single F1 score. + + Parameters + ---------- + tp, fp, fn : np.ndarray (k,) + Per-class true-positive, false-positive and false-negative counts. + mode : {"micro", "macro"} + ``"micro"`` pools counts across classes before computing one F1. + ``"macro"`` computes per-class F1 then averages them equally. + """ + if mode == "micro": + tp_s, fp_s, fn_s = tp.sum(), fp.sum(), fn.sum() + precision = tp_s / (tp_s + fp_s) if (tp_s + fp_s) > 0 else 0.0 + recall = tp_s / (tp_s + fn_s) if (tp_s + fn_s) > 0 else 0.0 + if precision + recall > 0: + return float(2 * precision * recall / (precision + recall)) + return 0.0 + + if mode == "macro": + f1s = [] + for k in range(len(tp)): + denom_p = tp[k] + fp[k] + denom_r = tp[k] + fn[k] + precision = tp[k] / denom_p if denom_p > 0 else 0.0 + recall = tp[k] / denom_r if denom_r > 0 else 0.0 + if precision + recall > 0: + f1s.append(2 * precision * recall / (precision + recall)) + else: + f1s.append(0.0) + return float(np.mean(f1s)) if f1s else 0.0 + + raise ValueError( + f"Unknown mode={mode!r}; expected 'micro' or 'macro'." + ) + + +def event_iou_f1(y_true, y_pred, iou_threshold=0.5, + matching_strategy="greedy", mode="micro"): + """F1 over binary class columns on IoU-matched event spans. + + Predicted spans are matched to ground-truth spans by 1-D IoU using the + requested strategy. For each matched pair, class probabilities are + thresholded at 0.5 and contribute true positives / false positives / false + negatives over the k binary class columns. Unmatched ground-truth spans + count as false negatives for all their active classes, while unmatched and + duplicate predictions count as false positives for all their active + classes, so both kinds of error are penalised. + + Parameters + ---------- + y_true : List[np.ndarray (N, 2+k)] + Ground-truth padded event targets. All-zero class columns = empty slot. + y_pred : List[np.ndarray (N, 2+k)] + Predicted outputs. Cols 0-1 = start/width, cols 2: = class scores. iou_threshold : float + Minimum IoU to accept a span match (default 0.5). + matching_strategy : {"greedy", "hungarian"} + Span matching strategy (default "greedy"). See :func:`_match_spans`. + mode : {"micro", "macro"} + Class-averaging mode for the final F1 (default "micro"). Returns ------- - float — micro-F1 over class columns + float — class F1 score aggregated over all series. """ - tp_total = fp_total = fn_total = 0 + # Infer the number of class columns from the first 2-D sample available. + n_classes = None + for arr in list(y_true) + list(y_pred): + a = np.asarray(arr) + if a.ndim == 2 and a.shape[1] > 2: + n_classes = a.shape[1] - 2 + break + if n_classes is None: + return float("nan") + + tp = np.zeros(n_classes) + fp = np.zeros(n_classes) + fn = np.zeros(n_classes) for gt, pr in zip(y_true, y_pred): gt = np.asarray(gt) pr = np.asarray(pr) + # Keep only occupied ground-truth slots and confident predictions. gt_mask = gt[:, 2:].sum(axis=1) > 0 pr_mask = pr[:, 2:].max(axis=1) > 0.5 @@ -456,41 +593,31 @@ def event_class_f1(y_true, y_pred, iou_threshold=0.5): G = gt_spans.shape[0] P = pr_spans.shape[0] - matched_gt = {} - matched_pr = set() - for gi in range(G): - best_iou = 0.0 - best_pi = -1 - for pi in range(P): - if pi in matched_pr: - continue - iou = _span_iou( - pr_spans[pi, 0], pr_spans[pi, 1], - gt_spans[gi, 0], gt_spans[gi, 1], - ) - if iou > best_iou: - best_iou = iou - best_pi = pi - if best_iou >= iou_threshold and best_pi >= 0: - matched_gt[gi] = best_pi - matched_pr.add(best_pi) + pairs = _match_spans(gt_spans, pr_spans, iou_threshold, + matching_strategy) + matched_gt = {gi for gi, _ in pairs} + matched_pr = {pi for _, pi in pairs} - for gi, pi in matched_gt.items(): + # Matched pairs: per-class agreement on the thresholded class columns. + for gi, pi in pairs: gt_cls = (gt_spans[gi, 2:] > 0.5).astype(int) pr_cls = (pr_spans[pi, 2:] > 0.5).astype(int) - tp_total += int((gt_cls & pr_cls).sum()) - fp_total += int(((1 - gt_cls) & pr_cls).sum()) - fn_total += int((gt_cls & (1 - pr_cls)).sum()) + tp += gt_cls & pr_cls + fp += (1 - gt_cls) & pr_cls + fn += gt_cls & (1 - pr_cls) + # Unmatched ground-truth spans: every active class is a false negative. for gi in range(G): if gi not in matched_gt: - fn_total += int((gt_spans[gi, 2:] > 0.5).sum()) + fn += (gt_spans[gi, 2:] > 0.5).astype(int) + + # Unmatched / duplicate predictions: every active class is a false + # positive (penalty for over-prediction). + for pi in range(P): + if pi not in matched_pr: + fp += (pr_spans[pi, 2:] > 0.5).astype(int) - precision = tp_total / (tp_total + fp_total) if (tp_total + fp_total) > 0 else 0.0 - recall = tp_total / (tp_total + fn_total) if (tp_total + fn_total) > 0 else 0.0 - if precision + recall > 0: - return float(2 * precision * recall / (precision + recall)) - return 0.0 + return _f1_from_class_counts(tp, fp, fn, mode) # --------------------------------------------------------------------------- @@ -519,12 +646,9 @@ def event_class_f1(y_true, y_pred, iou_threshold=0.5): EVENT_METRICS = { "map_iou": map_iou, - << << << < HEAD "f1_det": f1_det, - == == == = "event_span_iou": event_span_iou, - "event_class_f1": event_class_f1, - >>>>>> > 0d18cdca10c37ad36a1377ba9308990025a2a078 + "event_iou_f1": event_iou_f1, } ALL_METRICS = { From 4f181b6f8006101fbc7a450e8373140b557c442f Mon Sep 17 00:00:00 2001 From: Tianjun Hou Date: Fri, 29 May 2026 12:26:42 +0200 Subject: [PATCH 08/15] REFACTOR merge event-detection into Chronos solver - Move _get_linear_cosine_scheduler, precompute_embeddings, fit_event_head from solvers/chronos_event.py to benchmark_utils/adapters/event_detection.py - Add event_detection branch to solvers/chronos.py (import, SUPPORTED_TASKS, _CHRONOS_D, parameters, set_objective block, run() elif) - Delete solvers/chronos_event.py Co-authored-by: Cursor --- benchmark_utils/adapters/event_detection.py | 164 ++++++++++ solvers/chronos.py | 74 ++++- solvers/chronos_event.py | 326 -------------------- 3 files changed, 230 insertions(+), 334 deletions(-) delete mode 100644 solvers/chronos_event.py diff --git a/benchmark_utils/adapters/event_detection.py b/benchmark_utils/adapters/event_detection.py index ca97af0..0c92c82 100644 --- a/benchmark_utils/adapters/event_detection.py +++ b/benchmark_utils/adapters/event_detection.py @@ -274,3 +274,167 @@ def predict(self, x: np.ndarray) -> np.ndarray: pos = torch.sigmoid(pos_logits[0]).cpu().numpy() # (N, 2) cls = torch.sigmoid(cls_logits[0]).cpu().numpy() # (N, k) return np.concatenate([pos, cls], axis=-1).astype(np.float32) # (N, 2+k) + + +# --------------------------------------------------------------------------- +# Training helpers (used by solvers/chronos.py) +# --------------------------------------------------------------------------- + +def _get_linear_cosine_scheduler(optimizer, warmup_epochs, total_epochs): + """Linear warmup + cosine annealing LR scheduler (epoch-level).""" + import math + + def lr_lambda(epoch): + if epoch < warmup_epochs: + return float(epoch + 1) / float(max(1, warmup_epochs)) + progress = float(epoch - warmup_epochs) / float( + max(1, total_epochs - warmup_epochs) + ) + return 0.5 * (1.0 + math.cos(math.pi * progress)) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + +def precompute_embeddings(pipeline, X_train): + """Embed every training series via a frozen Chronos pipeline. + + Each series is embedded channel-by-channel and the per-channel + (T_tok, D) tensors are mean-pooled to produce a single (T_tok, D) + representation. Results are cached on CPU. + + Parameters + ---------- + pipeline : ChronosPipeline or Chronos2Pipeline + A loaded (and frozen) Chronos pipeline exposing ``.embed()``. + X_train : List[np.ndarray (T, C)] + Training time series. + + Returns + ------- + List[torch.Tensor] — one (T_tok, D) float32 CPU tensor per series. + """ + Z_train = [] + with torch.no_grad(): + for x in X_train: + x = np.asarray(x, dtype=np.float32) + C = x.shape[1] + channel_embs = [] + for c in range(C): + ctx = torch.tensor(x[:, c], dtype=torch.float32) + emb, _ = pipeline.embed(ctx.unsqueeze(0)) # (1, T_tok, D) + channel_embs.append(emb.squeeze(0).float().cpu()) + stacked = torch.stack(channel_embs, dim=0) # (C, T_tok, D) + Z_train.append(stacked.mean(dim=0)) # (T_tok, D) + return Z_train + + +def fit_event_head( + Z_train, + y_train, + n_classes, + d_model, + device, + batch_size=32, + num_epochs=100, + lr=3e-4, + weight_decay=1e-4, + warmup_epochs=5, + num_dec_layers=2, + lambda_cls=1.0, +): + """Train an EventHead on pre-computed Chronos embeddings. + + Parameters + ---------- + Z_train : List[torch.Tensor (T_tok, D)] + Pre-computed encoder embeddings (CPU), one per training series. + y_train : List[np.ndarray (N, 2+k)] + Padded event targets, one per training series. + n_classes : int + Number of binary class columns k. + d_model : int + Encoder hidden dimension D. + device : str + Torch device, e.g. ``"cuda"`` or ``"cpu"``. + batch_size, num_epochs, lr, weight_decay : training hyperparameters. + warmup_epochs : int + Linear warmup duration; cosine decay thereafter. + num_dec_layers : int + Transformer decoder depth. + lambda_cls : float + Weight of the classification loss relative to the position loss. + + Returns + ------- + EventHead — trained, in eval mode, on ``device``. + """ + head = EventHead( + d_model=d_model, + n_classes=n_classes, + num_queries=10, + num_decoder_layers=num_dec_layers, + nhead=8, + ).to(device) + head.train() + + optimizer = torch.optim.AdamW( + head.parameters(), lr=lr, weight_decay=weight_decay + ) + scheduler = _get_linear_cosine_scheduler(optimizer, warmup_epochs, num_epochs) + + N_train = len(Z_train) + use_amp = device == "cuda" + scaler = torch.amp.GradScaler("cuda", enabled=use_amp) + + for epoch in range(num_epochs): + indices = np.random.permutation(N_train) + epoch_loss = 0.0 + num_batches = 0 + + for batch_start in range(0, N_train, batch_size): + batch_idx = indices[batch_start: batch_start + batch_size] + + embs = [Z_train[i] for i in batch_idx] + max_ttok = max(e.shape[0] for e in embs) + D = embs[0].shape[1] + B = len(embs) + + memory = torch.zeros(B, max_ttok, D, dtype=torch.float32) + for bi, e in enumerate(embs): + memory[bi, : e.shape[0]] = e + memory = memory.to(device) + + y_batch = torch.tensor( + np.stack([y_train[i] for i in batch_idx]), + dtype=torch.float32, + device=device, + ) # (B, N, 2+k) + + optimizer.zero_grad() + with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=use_amp): + pos_logits, cls_logits = head(memory) + loss = head.compute_loss( + pos_logits, cls_logits, y_batch, lambda_cls=lambda_cls + ) + + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(head.parameters(), max_norm=1.0) + scaler.step(optimizer) + scaler.update() + + epoch_loss += loss.item() + num_batches += 1 + + scheduler.step() + + if (epoch + 1) % 10 == 0 or epoch == 0: + avg = epoch_loss / max(num_batches, 1) + lr_now = scheduler.get_last_lr()[0] + print( + f" Epoch {epoch + 1:3d}/{num_epochs} | " + f"loss={avg:.4f} | lr={lr_now:.2e}" + ) + + head.eval() + return head diff --git a/solvers/chronos.py b/solvers/chronos.py index 44619b0..e9cd1cb 100644 --- a/solvers/chronos.py +++ b/solvers/chronos.py @@ -4,8 +4,7 @@ - forecasting : zero-shot via ChronosPipeline - classification : linear probe on pooled encoder embeddings - anomaly_detection : forecast-residual on top of the same forecaster - -Anomaly detection is currently broken and skipped. + - event_detection : EventHead trained on frozen encoder embeddings Model loading is done in ``set_objective`` (untimed). Inference batches every (series, cutoff) pair into a single ``Chronos2Pipeline.predict`` @@ -27,12 +26,20 @@ UnpooledEncoder, ) from benchmark_utils.adapters.base import BaseTSFMAdapter +from benchmark_utils.adapters.event_detection import ( + ChronosEventAdapter, + fit_event_head, + precompute_embeddings, +) from benchmark_utils.adapters.forecast_residual import ForecastResidualAdapter from benchmark_utils.base_solver import BaseTSFMSolver from benchmark_utils.inputs import ForecastInput from benchmark_utils.outputs import ForecastOutput -SUPPORTED_TASKS = {"forecasting", "classification", "anomaly_detection"} +SUPPORTED_TASKS = {"forecasting", "classification", "anomaly_detection", "event_detection"} + +# Chronos encoder output dimension by model size +_CHRONOS_D = {"tiny": 64, "mini": 128, "small": 512, "base": 768, "large": 1024} POOLERS = { "mean": MeanPooler, @@ -161,10 +168,9 @@ class Solver(BaseTSFMSolver): ``ChronosPipeline.embed`` (post-final-norm). pooler : {"mean", "max", "last"} Pooling strategy over the time-token axis for classification. - task_adaptation : str - Per-task usage of the forecaster: - ``"zeroshot"`` — direct forecasting (forecasting only) - ``"forecast_residual"`` — anomaly score = forecast error (AD only) + model_path : str + Local directory path to load the Chronos model from. When empty + (default), the model is loaded from HuggingFace Hub. """ name = "Chronos" @@ -175,9 +181,31 @@ class Solver(BaseTSFMSolver): "model_size": ["small"], "layer": [None], "pooler": ["mean"], + # event_detection — single values so no cross-product for other tasks + "model_path": [""], + "batch_size": [32], + "num_epochs": [100], + "lr": [3e-4], + "weight_decay": [1e-4], + "warmup_epochs": [5], + "num_dec_layers": [2], + "lambda_cls": [1.0], } - def __init__(self, model_size="small", layer=None, pooler="mean"): + def __init__( + self, + model_size="small", + layer=None, + pooler="mean", + model_path="", + batch_size=32, + num_epochs=100, + lr=3e-4, + weight_decay=1e-4, + warmup_epochs=5, + num_dec_layers=2, + lambda_cls=1.0, + ): """Initialize Chronos-specific state. Parameters @@ -188,11 +216,21 @@ def __init__(self, model_size="small", layer=None, pooler="mean"): Encoder block index for classification embeddings. pooler : {"mean", "max", "last"}, default="mean" Pooling strategy over the time-token axis for classification. + model_path : str, default="" + Local model directory; empty = load from HuggingFace Hub. """ super().__init__( model_size=model_size, layer=layer, pooler=pooler, + model_path=model_path, + batch_size=batch_size, + num_epochs=num_epochs, + lr=lr, + weight_decay=weight_decay, + warmup_epochs=warmup_epochs, + num_dec_layers=num_dec_layers, + lambda_cls=lambda_cls, ) self._pipeline = None self._loaded_model = None @@ -206,6 +244,7 @@ def load_model(self, device, dtype): from chronos import Chronos2Pipeline model_id = f"autogluon/chronos-2-{self.model_size}" + model_id = self.model_path if self.model_path else model_id if not hasattr(self, "_pipeline") or self._loaded_model != model_id: self._pipeline = Chronos2Pipeline.from_pretrained( model_id, @@ -215,6 +254,15 @@ def load_model(self, device, dtype): self._loaded_model = model_id return self._pipeline + def set_objective(self, X_train, y_train, task, **meta): + """Load pipeline then pre-compute embeddings for event_detection.""" + super().set_objective(X_train, y_train, task, **meta) + + if task == "event_detection": + self._n_classes = int(meta["n_classes"]) + self._d_model = _CHRONOS_D.get(self.model_size, 512) + self._Z_train = precompute_embeddings(self.model, X_train) + def forecast_batch(self, inputs): """Chronos-specific batch prediction. @@ -280,4 +328,14 @@ def predict(self, x: ForecastInput) -> ForecastOutput: forecaster = _ForecasterForAD(self, quantile_levels) return ForecastResidualAdapter(forecaster, prediction_length=1) + elif task == "event_detection": + device = "cuda" if torch.cuda.is_available() else "cpu" + head = fit_event_head( + self._Z_train, self.y_train, self._n_classes, self._d_model, + device, self.batch_size, self.num_epochs, self.lr, + self.weight_decay, self.warmup_epochs, self.num_dec_layers, + self.lambda_cls, + ) + return ChronosEventAdapter(model, head, device, self._n_classes) + raise ValueError(f"Unknown task: {task}") diff --git a/solvers/chronos_event.py b/solvers/chronos_event.py deleted file mode 100644 index e72fe77..0000000 --- a/solvers/chronos_event.py +++ /dev/null @@ -1,326 +0,0 @@ -"""Chronos-based event-detection solver for the TSFM benchmark. - -Supports: - - event_detection : frozen Chronos T5 encoder + trainable EventHead - -Overview --------- -Event detection is a structured-prediction task. Each input series x (T, C) -is mapped to a fixed-size set of N=10 span predictions: - - output : (N=10, 2+k) - col 0 : event start, normalised to [0, 1] over T=512 - col 1 : event length, normalised to [0, 1] over T=512 - col 2.. : k binary class probability columns - -y_train / y_test format (padded) ---------------------------------- -Each element is a float32 array of shape (N=10, 2+k). Empty/no-event slots -are represented as all-zero rows. Real event slots satisfy - row[2:].sum() >= 1 -The solver assumes this padded format and reads n_classes from meta -(key "n_classes") and T from meta (key "T", default 512). - -Meta keys expected from the dataset -------------------------------------- - n_classes : int number of binary class columns k - T : int series length (default 512) - -These must be returned by the dataset's get_data() via the ``extra`` dict -that gets forwarded to objective.set_data() and then to get_objective(). - -Architecture (frozen encoder, trained head) --------------------------------------------- -Chronos is purely univariate. To handle C channels we embed each channel -independently via pipeline.embed() and **mean-pool the (T_tok, D) tensors** -across channels. This is parameter-free and works for any C at inference time. -Alternatives (concat, attention aggregation) are noted in event_detection.py. - -Because Chronos is frozen, all training embeddings are pre-computed once in -set_objective() (untimed) and cached in CPU RAM. The timed run() block only -trains the small EventHead — very fast on H100. - -Hyperparameters (H100 defaults) ---------------------------------- - model_size : "small" (D=512, 46 M params) - model_path : "" empty = download from HuggingFace Hub - set to a local directory path to load offline - e.g. "/path/to/models/chronos_t5_small" - batch_size : 32 - num_epochs : 100 - lr : 3e-4 - weight_decay : 1e-4 - warmup_epochs : 5 (linear warmup, cosine decay thereafter) - num_queries : 10 (= N, fixed) - num_dec_layers : 2 - lambda_cls : 1.0 (weight of class loss vs. position loss) - -Benchopt timing contract --------------------------- - set_objective() — model loading + embedding pre-computation (UNTIMED) - run() — EventHead training (TIMED) - get_result() — returns {"model": ChronosEventAdapter} -""" - -import math - -import numpy as np -import torch -from benchopt import BaseSolver - -from benchmark_utils.adapters.event_detection import ChronosEventAdapter, EventHead - - -SUPPORTED_TASKS = {"event_detection"} - -# Chronos encoder output dimensions by model size -_CHRONOS_D = { - "tiny": 64, - "mini": 128, - "small": 512, - "base": 768, - "large": 1024, -} - - -def _get_linear_cosine_scheduler(optimizer, warmup_epochs, total_epochs): - """Linear warmup + cosine annealing LR scheduler (epoch-level).""" - def lr_lambda(epoch): - if epoch < warmup_epochs: - return float(epoch + 1) / float(max(1, warmup_epochs)) - progress = float(epoch - warmup_epochs) / float( - max(1, total_epochs - warmup_epochs) - ) - return 0.5 * (1.0 + math.cos(math.pi * progress)) - - return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) - - -class Solver(BaseSolver): - """Chronos encoder (frozen) + EventHead event-detection solver. - - See module docstring for full details on the architecture, data format, - and hyperparameter choices. - """ - - name = "Chronos-EventDetection" - - requirements = [ - "pip::chronos-forecasting>=1.4", - "pip::torch", - ] - - sampling_strategy = "run_once" - - parameters = { - "model_size": ["small"], - "model_path": [""], # empty = load from HuggingFace Hub - "batch_size": [32], - "num_epochs": [0, 100], - "lr": [3e-4], - "weight_decay": [1e-4], - "warmup_epochs": [5], - "num_dec_layers": [2], - "lambda_cls": [1.0], - } - - def skip(self, task, **kwargs): - if task not in SUPPORTED_TASKS: - return True, f"Chronos-EventDetection does not support task={task!r}" - return False, None - - # ------------------------------------------------------------------ - # set_objective — UNTIMED - # ------------------------------------------------------------------ - - def set_objective(self, X_train, y_train, task, **meta): - """Load Chronos, freeze it, and pre-compute training embeddings. - - Parameters - ---------- - X_train : List[np.ndarray (T, C)] - y_train : List[np.ndarray (N=10, 2+k)] padded event targets - task : str must be "event_detection" - **meta : must include "n_classes" (int) and optionally "T" (int) - """ - import torch as _torch - from chronos import ChronosPipeline - - self.task = task - self.X_train = X_train - self.y_train = y_train - self.meta = meta - - # --- Infer task dimensions --- - self.n_classes = int(meta["n_classes"]) - self.T = int(meta.get("T", 512)) - self.k = self.n_classes # alias - - # --- Device --- - self.device = "cuda" if _torch.cuda.is_available() else "cpu" - - # --- Resolve model identifier --- - # model_path="" (default) → load from HuggingFace Hub - # model_path="/some/local/dir" → load from that local directory - model_id = ( - self.model_path.strip() - if self.model_path.strip() - else f"amazon/chronos-t5-{self.model_size}" - ) - - # --- Load Chronos pipeline (once; cached across dataset configs) --- - should_reload = ( - not hasattr(self, "_pipeline") - or not hasattr(self, "_loaded_model_id") - or self._loaded_model_id != model_id - ) - if should_reload: - self._pipeline = ChronosPipeline.from_pretrained( - model_id, - device_map="auto", - dtype=_torch.bfloat16, - ) - self._loaded_model_id = model_id - print(f"Loaded Chronos checkpoint: {model_id}") - - # --- Freeze encoder --- - for param in self._pipeline.model.parameters(): - param.requires_grad = False - self._pipeline.model.eval() - - # --- Encoder output dimension --- - self.d_model = _CHRONOS_D.get(self.model_size, 512) - - # --- Pre-compute training embeddings (frozen encoder → one-time) --- - print( - f"Pre-computing embeddings for {len(X_train)} training series " - f"(C={X_train[0].shape[1]}, T={self.T}) ..." - ) - self._Z_train = [] # List of (T_tok, D) CPU tensors - with _torch.no_grad(): - for x in X_train: - emb = self._embed_series(x) # (T_tok, D) float32 on CPU - self._Z_train.append(emb) - - print("Embedding pre-computation complete.") - - # ------------------------------------------------------------------ - # run — TIMED - # ------------------------------------------------------------------ - - def run(self, _): - """Train the EventHead on top of cached Chronos embeddings.""" - import torch as _torch - - head = EventHead( - d_model=self.d_model, - n_classes=self.n_classes, - num_queries=10, - num_decoder_layers=self.num_dec_layers, - nhead=8, - ).to(self.device) - head.train() - - optimizer = _torch.optim.AdamW( - head.parameters(), - lr=self.lr, - weight_decay=self.weight_decay, - ) - scheduler = _get_linear_cosine_scheduler( - optimizer, self.warmup_epochs, self.num_epochs - ) - - N_train = len(self._Z_train) - use_amp = self.device == "cuda" - scaler = _torch.amp.GradScaler("cuda", enabled=use_amp) - - for epoch in range(self.num_epochs): - indices = np.random.permutation(N_train) - epoch_loss = 0.0 - num_batches = 0 - - for batch_start in range(0, N_train, self.batch_size): - batch_idx = indices[batch_start: batch_start + self.batch_size] - - # --- Collate: pad to same T_tok (all same T → same T_tok) --- - embs = [self._Z_train[i] for i in batch_idx] - max_ttok = max(e.shape[0] for e in embs) - D = embs[0].shape[1] - B = len(embs) - - memory = _torch.zeros(B, max_ttok, D, dtype=_torch.float32) - for bi, e in enumerate(embs): - memory[bi, : e.shape[0]] = e - memory = memory.to(self.device) - - # --- Targets --- - y_batch = _torch.tensor( - np.stack([self.y_train[i] for i in batch_idx]), - dtype=_torch.float32, - device=self.device, - ) # (B, N, 2+k) - - # --- Forward + loss --- - optimizer.zero_grad() - with _torch.amp.autocast("cuda", dtype=_torch.bfloat16, - enabled=use_amp): - pos_logits, cls_logits = head(memory) - loss = head.compute_loss( - pos_logits, cls_logits, y_batch, - lambda_cls=self.lambda_cls, - ) - - scaler.scale(loss).backward() - scaler.unscale_(optimizer) - _torch.nn.utils.clip_grad_norm_(head.parameters(), max_norm=1.0) - scaler.step(optimizer) - scaler.update() - - epoch_loss += loss.item() - num_batches += 1 - - scheduler.step() - - if (epoch + 1) % 10 == 0 or epoch == 0: - avg = epoch_loss / max(num_batches, 1) - lr_now = scheduler.get_last_lr()[0] - print( - f" Epoch {epoch + 1:3d}/{self.num_epochs} | " - f"loss={avg:.4f} | lr={lr_now:.2e}" - ) - - head.eval() - self._adapter = ChronosEventAdapter( - pipeline=self._pipeline, - head=head, - device=self.device, - n_classes=self.n_classes, - T=self.T, - ) - - # ------------------------------------------------------------------ - # get_result - # ------------------------------------------------------------------ - - def get_result(self): - return {"model": self._adapter} - - # ------------------------------------------------------------------ - # Internal helpers - # ------------------------------------------------------------------ - - def _embed_series(self, x: np.ndarray) -> torch.Tensor: - """Embed a (T, C) array via frozen Chronos, mean-pool across C. - - Returns - ------- - Tensor (T_tok, D) float32 on CPU - """ - C = x.shape[1] - channel_embs = [] - for c in range(C): - ctx = torch.tensor(x[:, c], dtype=torch.float32) - emb, _ = self._pipeline.embed(ctx.unsqueeze(0)) # (1, T_tok, D) - channel_embs.append(emb.squeeze(0).float().cpu()) - stacked = torch.stack(channel_embs, dim=0) # (C, T_tok, D) - return stacked.mean(dim=0) # (T_tok, D) From 30b3f917ad761f53b8536000e667e2b937dec975 Mon Sep 17 00:00:00 2001 From: Tianjun Hou Date: Fri, 29 May 2026 12:37:24 +0200 Subject: [PATCH 09/15] ENH make num_queries configurable; pad/truncate event labels - Add num_queries parameter to Solver.parameters, __init__, and fit_event_head call in chronos.py - Add _pad_or_truncate_labels helper in event_detection.py that truncates or zero-pads label arrays to exactly num_queries rows - Apply padding/truncation upfront in fit_event_head so variable-length label lists can be batched safely Co-authored-by: Cursor --- benchmark_utils/adapters/event_detection.py | 42 +++++++++++++++++++-- solvers/chronos.py | 5 ++- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/benchmark_utils/adapters/event_detection.py b/benchmark_utils/adapters/event_detection.py index 0c92c82..2b35d7e 100644 --- a/benchmark_utils/adapters/event_detection.py +++ b/benchmark_utils/adapters/event_detection.py @@ -328,6 +328,34 @@ def precompute_embeddings(pipeline, X_train): return Z_train +def _pad_or_truncate_labels(y: np.ndarray, num_queries: int) -> np.ndarray: + """Pad or truncate a label array to exactly ``num_queries`` rows. + + Parameters + ---------- + y : (N, 2+k) float array + Raw event labels for one series. ``N`` may be smaller or larger + than ``num_queries``. + num_queries : int + Target number of event slots. + + Returns + ------- + (num_queries, 2+k) float32 array + Rows beyond the original ``N`` are filled with zeros (no-event). + Rows beyond ``num_queries`` in the original are discarded. + """ + y = np.asarray(y, dtype=np.float32) + N, width = y.shape + if N == num_queries: + return y + if N > num_queries: + return y[:num_queries] + # N < num_queries — pad with zero rows + pad = np.zeros((num_queries - N, width), dtype=np.float32) + return np.concatenate([y, pad], axis=0) + + def fit_event_head( Z_train, y_train, @@ -341,6 +369,7 @@ def fit_event_head( warmup_epochs=5, num_dec_layers=2, lambda_cls=1.0, + num_queries=10, ): """Train an EventHead on pre-computed Chronos embeddings. @@ -348,8 +377,10 @@ def fit_event_head( ---------- Z_train : List[torch.Tensor (T_tok, D)] Pre-computed encoder embeddings (CPU), one per training series. - y_train : List[np.ndarray (N, 2+k)] - Padded event targets, one per training series. + y_train : List[np.ndarray (N_i, 2+k)] + Event targets, one per training series. ``N_i`` may differ across + series and need not equal ``num_queries``; labels are automatically + padded (with zeros) or truncated to ``num_queries`` rows. n_classes : int Number of binary class columns k. d_model : int @@ -363,15 +394,20 @@ def fit_event_head( Transformer decoder depth. lambda_cls : float Weight of the classification loss relative to the position loss. + num_queries : int + Number of event slots (decoder queries); default 10. Returns ------- EventHead — trained, in eval mode, on ``device``. """ + # Normalise all labels to (num_queries, 2+k) once, before the training loop + y_train = [_pad_or_truncate_labels(y, num_queries) for y in y_train] + head = EventHead( d_model=d_model, n_classes=n_classes, - num_queries=10, + num_queries=num_queries, num_decoder_layers=num_dec_layers, nhead=8, ).to(device) diff --git a/solvers/chronos.py b/solvers/chronos.py index e9cd1cb..9d59513 100644 --- a/solvers/chronos.py +++ b/solvers/chronos.py @@ -190,6 +190,7 @@ class Solver(BaseTSFMSolver): "warmup_epochs": [5], "num_dec_layers": [2], "lambda_cls": [1.0], + "num_queries": [10], } def __init__( @@ -205,6 +206,7 @@ def __init__( warmup_epochs=5, num_dec_layers=2, lambda_cls=1.0, + num_queries=10, ): """Initialize Chronos-specific state. @@ -231,6 +233,7 @@ def __init__( warmup_epochs=warmup_epochs, num_dec_layers=num_dec_layers, lambda_cls=lambda_cls, + num_queries=num_queries, ) self._pipeline = None self._loaded_model = None @@ -334,7 +337,7 @@ def predict(self, x: ForecastInput) -> ForecastOutput: self._Z_train, self.y_train, self._n_classes, self._d_model, device, self.batch_size, self.num_epochs, self.lr, self.weight_decay, self.warmup_epochs, self.num_dec_layers, - self.lambda_cls, + self.lambda_cls, self.num_queries, ) return ChronosEventAdapter(model, head, device, self._n_classes) From be07b94bfa0762a2ecc1deac7ab024f5f8381402 Mon Sep 17 00:00:00 2001 From: ABarneche Date: Fri, 29 May 2026 14:57:38 +0200 Subject: [PATCH 10/15] metrics for event detection + adapt chronos fro event detection regression --- datasets/mitdb.py | 4 +- solvers/chronos.py | 278 +++++++++--------- tests/benchmark_utils/test_metrics_map_iou.py | 212 ++++++++++++- 3 files changed, 347 insertions(+), 147 deletions(-) diff --git a/datasets/mitdb.py b/datasets/mitdb.py index 61e9ccc..1083b5b 100644 --- a/datasets/mitdb.py +++ b/datasets/mitdb.py @@ -21,7 +21,7 @@ class 4 Q — Unknown / pacemaker artefact X_test : List[np.ndarray (T_j, 2)] test portions y_test : List[np.ndarray (N, 2+K)] same format, same N task : "event_detection" -metrics : ["map_iou"] +metrics : ["event_span_iou", "event_iou_f1"] extra : n_classes (int) K above """ @@ -204,6 +204,6 @@ def _pad(arrays): X_test=X_test, y_test=_pad(y_test), task="event_detection", - metrics=["map_iou"], + metrics=["event_span_iou", "event_iou_f1"], n_classes=self.n_classes, ) diff --git a/solvers/chronos.py b/solvers/chronos.py index 9d59513..d16f72e 100644 --- a/solvers/chronos.py +++ b/solvers/chronos.py @@ -7,7 +7,7 @@ - event_detection : EventHead trained on frozen encoder embeddings Model loading is done in ``set_objective`` (untimed). Inference batches -every (series, cutoff) pair into a single ``Chronos2Pipeline.predict`` +every (series, cutoff) pair into a single ``ChronosPipeline.predict`` call — the pipeline accepts a list of variable-length tensors and applies left-padding internally, so all the per-cutoff work happens in one forward pass. @@ -15,6 +15,7 @@ import numpy as np import torch +from benchopt import BaseSolver from chronos import ChronosPipeline from benchmark_utils.adapters import ( @@ -32,7 +33,6 @@ precompute_embeddings, ) from benchmark_utils.adapters.forecast_residual import ForecastResidualAdapter -from benchmark_utils.base_solver import BaseTSFMSolver from benchmark_utils.inputs import ForecastInput from benchmark_utils.outputs import ForecastOutput @@ -48,31 +48,94 @@ } -def _to_context(x): - """Reshape ``(T, V)`` or ``(B, T, V)`` to Chronos input ``(B, V, T)``.""" - x = np.asarray(x, dtype=np.float32) - if x.ndim == 2: - x = x[None] - return x.transpose(0, 2, 1) +class _ChronosForecaster(BaseTSFMAdapter): + """Batched Chronos v1 adapter; quantiles are derived from sample draws.""" + + DEFAULT_QUANTILE_LEVELS = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9) + + def __init__(self, pipeline, prediction_length, quantile_levels=None): + self.pipeline = pipeline + self.prediction_length = prediction_length + self.quantile_levels = quantile_levels or self.DEFAULT_QUANTILE_LEVELS + + # ------------------------------------------------------------------ + # Template method — subclasses override _build_inputs / _assemble + # ------------------------------------------------------------------ + + def predict(self, x: ForecastInput) -> ForecastOutput: + inputs, layout, per_series_shape = self._build_inputs(x) + if not inputs: + return ForecastOutput(quantiles=[], quantile_levels=self.quantile_levels) + + with torch.no_grad(): + output = self.pipeline.predict( + inputs, + prediction_length=self.prediction_length, + ) + return self._assemble_output(output, layout, per_series_shape) + + def _build_inputs(self, x): + """Build list of 1-D tensors (one per channel) and track layout.""" + inputs = [] + layout = [] # (series_idx, cutoff_idx, channel_idx) + per_series_shape = [] # (C, n_cutoffs) + for series_idx, (series, cutoffs) in enumerate(zip(x.x, x.cutoff_indexes)): + series = np.asarray(series, dtype=np.float32) + if series.ndim == 1: + series = series[:, None] + _, C = series.shape + per_series_shape.append((C, len(cutoffs))) + for cutoff_idx, cutoff in enumerate(cutoffs): + hist = series[:cutoff] + for c in range(C): + inputs.append(torch.from_numpy(hist[:, c])) + layout.append((series_idx, cutoff_idx, c)) + return inputs, layout, per_series_shape + + def _assemble_output(self, samples, layout, per_series_shape): + """Derive quantile fan from Monte-Carlo sample draws.""" + # samples: (n_inputs, num_samples, H) + q_arr = np.quantile( + samples.float().cpu().numpy(), + q=list(self.quantile_levels), + axis=1, + ).transpose(1, 0, 2) # (n_inputs, Q, H) + + Q = len(self.quantile_levels) + per_series = [ + np.empty((n_cutoffs, Q, self.prediction_length, C), dtype=np.float32) + for C, n_cutoffs in per_series_shape + ] + for i, (series_idx, cutoff_idx, c) in enumerate(layout): + per_series[series_idx][cutoff_idx, :, :, c] = q_arr[i] + + return ForecastOutput(quantiles=per_series, quantile_levels=self.quantile_levels) class _ChronosEmbedEncoder(UnpooledEncoder): - """Default path — uses ``Chronos2Pipeline.embed``. + """Default path — uses ``ChronosPipeline.embed``. - Returns hidden states *after* ``encoder.final_layer_norm`` for each - series in the batch. + Returns hidden states *after* ``encoder.final_layer_norm``. """ def __init__(self, pipeline): self.pipeline = pipeline def encode(self, X) -> np.ndarray: - context = _to_context(X) # (B, V, T) + # X: (B, T, V) or (T, V). + X = np.asarray(X, dtype=np.float32) + batched = X.ndim == 3 + if not batched: + X = X[None] # (1, T, V) + B, T, V = X.shape + + # Chronos is univariate — flatten B & V into the batch axis. + flat = X.reshape(B * V, T) # (B*V, T) with torch.no_grad(): - # embed returns a list of B tensors, each of shape (V, T, D). - embeddings, _ = self.pipeline.embed(context) - stacked = torch.stack(list(embeddings)) # (B, V, T, D) - return stacked.transpose(1, 2).float().cpu().numpy() # (B, T, V, D) + emb, _ = self.pipeline.embed(torch.from_numpy(flat)) # (B*V, T_tok, D) + + # (B*V, T_tok, D) -> (B, T_tok, V, D) + return emb.float().cpu().numpy().reshape(B, -1, V, emb.shape[-1]) class _ChronosHookEncoder(UnpooledEncoder): @@ -91,8 +154,15 @@ def __init__(self, pipeline, layer: int): ) self._block_idx = layer % n_blocks - def encode(self, x: np.ndarray) -> np.ndarray: - context = _to_context(x) # (B, V, T) + def encode(self, X) -> np.ndarray: + X = np.asarray(X, dtype=np.float32) + batched = X.ndim == 3 + if not batched: + X = X[None] # (1, T, V) + B, T, V = X.shape + + flat = X.reshape(B * V, T) # (B*V, T) + context = torch.from_numpy(flat) token_ids, attn_mask, _ = self.pipeline.tokenizer.context_input_transform( context ) @@ -104,8 +174,6 @@ def encode(self, x: np.ndarray) -> np.ndarray: captured = {} def _hook(_module, _inputs, output): - # Hook to capture the embeddings while performing a forward pass - # T5Block returns a tuple; first element is the hidden state. hidden = output[0] if isinstance(output, tuple) else output captured["h"] = hidden.detach() @@ -116,8 +184,8 @@ def _hook(_module, _inputs, output): finally: handle.remove() - # (C, T_tok, D) -> (T_tok, C, D) - return captured["h"].transpose(0, 1).float().cpu().numpy() + # (B*V, T_tok, D) -> (B, T_tok, V, D) + return captured["h"].float().cpu().numpy().reshape(B, -1, V, captured["h"].shape[-1]) def ChronosEncoder(pipeline, layer: int | None = None) -> UnpooledEncoder: @@ -156,8 +224,8 @@ def ChronosEncoder(pipeline, layer: int | None = None) -> UnpooledEncoder: # --------------------------------------------------------------------------- -class Solver(BaseTSFMSolver): - """Chronos-2 zero-shot solver. +class Solver(BaseSolver): + """Chronos zero-shot solver. Parameters ---------- @@ -175,7 +243,9 @@ class Solver(BaseTSFMSolver): name = "Chronos" - requirements = ["pip::chronos-forecasting>=2.2,<3"] + requirements = ["pip::chronos-forecasting>=2.2", "pip::torch"] + + sampling_strategy = "run_once" parameters = { "model_size": ["small"], @@ -193,118 +263,42 @@ class Solver(BaseTSFMSolver): "num_queries": [10], } - def __init__( - self, - model_size="small", - layer=None, - pooler="mean", - model_path="", - batch_size=32, - num_epochs=100, - lr=3e-4, - weight_decay=1e-4, - warmup_epochs=5, - num_dec_layers=2, - lambda_cls=1.0, - num_queries=10, - ): - """Initialize Chronos-specific state. - - Parameters - ---------- - model_size : str, default="small" - Chronos model variant to load. - layer : int or None, default=None - Encoder block index for classification embeddings. - pooler : {"mean", "max", "last"}, default="mean" - Pooling strategy over the time-token axis for classification. - model_path : str, default="" - Local model directory; empty = load from HuggingFace Hub. - """ - super().__init__( - model_size=model_size, - layer=layer, - pooler=pooler, - model_path=model_path, - batch_size=batch_size, - num_epochs=num_epochs, - lr=lr, - weight_decay=weight_decay, - warmup_epochs=warmup_epochs, - num_dec_layers=num_dec_layers, - lambda_cls=lambda_cls, - num_queries=num_queries, - ) - self._pipeline = None - self._loaded_model = None - - @property - def supported_tasks(self): - return SUPPORTED_TASKS - - def load_model(self, device, dtype): - """Load Chronos-2 pipeline (cached if already loaded).""" - from chronos import Chronos2Pipeline + def skip(self, task, **kwargs): + if task not in SUPPORTED_TASKS: + return True, f"Chronos solver does not support task={task!r}" + return False, None - model_id = f"autogluon/chronos-2-{self.model_size}" - model_id = self.model_path if self.model_path else model_id + def set_objective(self, X_train, y_train, task, **meta): + self.task = task + self.X_train = X_train + self.y_train = y_train + self.meta = meta + + # bfloat16 is fine on CUDA but poorly supported on CPU / MPS; + # fall back to float32 there so inference doesn't crash or stall. + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 if device == "cuda" else torch.float32 + model_id = self.model_path if self.model_path else f"amazon/chronos-t5-{self.model_size}" if not hasattr(self, "_pipeline") or self._loaded_model != model_id: - self._pipeline = Chronos2Pipeline.from_pretrained( + self._pipeline = ChronosPipeline.from_pretrained( model_id, device_map=device, dtype=dtype, ) self._loaded_model = model_id - return self._pipeline - - def set_objective(self, X_train, y_train, task, **meta): - """Load pipeline then pre-compute embeddings for event_detection.""" - super().set_objective(X_train, y_train, task, **meta) if task == "event_detection": - self._n_classes = int(meta["n_classes"]) - self._d_model = _CHRONOS_D.get(self.model_size, 512) - self._Z_train = precompute_embeddings(self.model, X_train) - - def forecast_batch(self, inputs): - """Chronos-specific batch prediction. - - Parameters - ---------- - inputs : list of torch.Tensor - Each tensor shape (C, T_cutoff) - - Returns - ------- - list of torch.Tensor - Each tensor shape (C, Q, H) - """ - with torch.no_grad(): - return self.model.predict(inputs, prediction_length=self.prediction_length) + self._Z_train = precompute_embeddings(self._pipeline, X_train) + self._d_model = self._Z_train[0].shape[-1] + self._n_classes = np.asarray(y_train[0]).shape[-1] - 2 - def build_adapter(self, task, model): - # TODO later: put that code in base_solver.py - # and make it rely on .forecast(), .embed() and .time_embed() only, once those are all properly coded - """Create task-specific adapter for Chronos.""" + def run(self, _): pred_len = self.meta.get("prediction_length", 1) + if self.task == "forecasting": + self._adapter = _ChronosForecaster(self._pipeline, pred_len) - if task == "forecasting": - self.prediction_length = pred_len - quantile_levels = tuple(float(q) for q in model.quantiles) - - # Create a simple adapter that calls self.forecast() - class _ForecastAdapter(BaseTSFMAdapter): - def __init__(self, solver, quantile_levels): - self.solver = solver - self.quantile_levels = quantile_levels - - def predict(self, x: ForecastInput) -> ForecastOutput: - return self.solver.forecast(x, self.solver.prediction_length, self.quantile_levels) - - return _ForecastAdapter(self, quantile_levels) - - elif task == "classification": - base_encoder = ChronosEncoder(model, layer=self.layer) + elif self.task == "classification": + base_encoder = ChronosEncoder(self._pipeline, layer=self.layer) encoder = Encoder(base_encoder, POOLERS[self.pooler]()) adapter = LinearProbeAdapter( encoder, @@ -312,26 +306,16 @@ def predict(self, x: ForecastInput) -> ForecastOutput: n_classes=self.meta.get("n_classes"), ) adapter.fit(self.X_train, self.y_train) - return adapter + self._adapter = adapter - elif task == "anomaly_detection": + elif self.task == "anomaly_detection": # AD uses one-step-ahead forecasts. - self.prediction_length = 1 - quantile_levels = (0.5,) - - # Create a forecaster adapter for residual-based anomaly detection - class _ForecasterForAD(BaseTSFMAdapter): - def __init__(self, solver, quantile_levels): - self.solver = solver - self.quantile_levels = quantile_levels - - def predict(self, x: ForecastInput) -> ForecastOutput: - return self.solver.forecast(x, 1, self.quantile_levels) - - forecaster = _ForecasterForAD(self, quantile_levels) - return ForecastResidualAdapter(forecaster, prediction_length=1) + self._adapter = ForecastResidualAdapter( + _ChronosForecaster(self._pipeline, prediction_length=1), + prediction_length=1, + ) - elif task == "event_detection": + elif self.task == "event_detection": device = "cuda" if torch.cuda.is_available() else "cpu" head = fit_event_head( self._Z_train, self.y_train, self._n_classes, self._d_model, @@ -339,6 +323,12 @@ def predict(self, x: ForecastInput) -> ForecastOutput: self.weight_decay, self.warmup_epochs, self.num_dec_layers, self.lambda_cls, self.num_queries, ) - return ChronosEventAdapter(model, head, device, self._n_classes) + self._adapter = ChronosEventAdapter( + self._pipeline, head, device, self._n_classes + ) + + else: + raise ValueError(f"Unknown task: {self.task}") - raise ValueError(f"Unknown task: {task}") + def get_result(self): + return {"model": self._adapter} diff --git a/tests/benchmark_utils/test_metrics_map_iou.py b/tests/benchmark_utils/test_metrics_map_iou.py index 4ad55a7..e7d5317 100644 --- a/tests/benchmark_utils/test_metrics_map_iou.py +++ b/tests/benchmark_utils/test_metrics_map_iou.py @@ -5,7 +5,16 @@ import numpy as np import pytest -from benchmark_utils.metrics import _ap_from_tp_fp, _iou_1d, f1_det, map_iou +from benchmark_utils.metrics import ( + _ap_from_tp_fp, + _f1_from_class_counts, + _iou_1d, + _match_spans, + event_iou_f1, + event_span_iou, + f1_det, + map_iou, +) def _evt(start, width, *class_cols): @@ -129,6 +138,207 @@ def test_map_prediction_in_wrong_series_is_fp(): assert map_iou(y_true, y_pred) == pytest.approx(0.0) +# --------------------------------------------------------------------------- +# Helpers for padded-array format used by event_span_iou / event_iou_f1 +# --------------------------------------------------------------------------- + +def _series(events, n_slots=5, n_classes=2): + """Padded (n_slots, 2+n_classes) array from list of (start, width, *cols).""" + arr = np.zeros((n_slots, 2 + n_classes)) + for i, ev in enumerate(events): + arr[i] = ev + return arr + + +# --------------------------------------------------------------------------- +# _match_spans +# --------------------------------------------------------------------------- + +def test_match_spans_empty_gt(): + gt = np.zeros((0, 4)) + pr = np.array([[0.0, 0.5, 0.9, 0.0]]) + assert _match_spans(gt, pr, 0.5, "greedy") == [] + +def test_match_spans_empty_pred(): + gt = np.array([[0.0, 0.5, 1.0, 0.0]]) + pr = np.zeros((0, 4)) + assert _match_spans(gt, pr, 0.5, "greedy") == [] + +def test_match_spans_perfect_match_greedy(): + gt = np.array([[0.0, 0.5, 1.0, 0.0]]) + pr = np.array([[0.0, 0.5, 0.9, 0.0]]) + assert _match_spans(gt, pr, 0.5, "greedy") == [(0, 0)] + +def test_match_spans_perfect_match_hungarian(): + gt = np.array([[0.0, 0.5, 1.0, 0.0]]) + pr = np.array([[0.0, 0.5, 0.9, 0.0]]) + assert _match_spans(gt, pr, 0.5, "hungarian") == [(0, 0)] + +def test_match_spans_below_threshold_no_match(): + # IoU([0,1), [0.5,1.5)) ≈ 0.33 < 0.5 + gt = np.array([[0.0, 1.0, 1.0, 0.0]]) + pr = np.array([[0.5, 1.0, 0.9, 0.0]]) + assert _match_spans(gt, pr, 0.5, "greedy") == [] + +def test_match_spans_greedy_vs_hungarian_differ(): + # pred0: IoU=1.0 with GT, score=0.6 + # pred1: IoU=0.8 with GT, score=0.9 + # Greedy: pred1 (higher score) goes first → matches GT → [(0, 1)] + # Hungarian: pred0 (higher IoU) is assigned → [(0, 0)] + gt = np.array([[0.0, 0.5, 1.0, 0.0]]) + pr = np.array([ + [0.0, 0.5, 0.6, 0.0], # pred0: IoU=1.0, score=0.6 + [0.1, 0.4, 0.9, 0.0], # pred1: IoU=0.8, score=0.9 + ]) + assert _match_spans(gt, pr, 0.5, "greedy") == [(0, 1)] + assert _match_spans(gt, pr, 0.5, "hungarian") == [(0, 0)] + +def test_match_spans_duplicate_pred_only_one_matched(): + gt = np.array([[0.0, 0.5, 1.0, 0.0]]) + pr = np.array([[0.0, 0.5, 0.9, 0.0], [0.0, 0.5, 0.8, 0.0]]) + assert len(_match_spans(gt, pr, 0.5, "greedy")) == 1 + +def test_match_spans_two_gt_two_pred_both_matched(): + gt = np.array([[0.0, 0.3, 1, 0], [0.5, 0.3, 0, 1]]) + pr = np.array([[0.0, 0.3, 1.0, 0.0], [0.5, 0.3, 0.0, 1.0]]) + assert set(_match_spans(gt, pr, 0.5, "greedy")) == {(0, 0), (1, 1)} + +def test_match_spans_invalid_strategy_raises(): + gt = np.array([[0.0, 0.5, 1.0, 0.0]]) + pr = np.array([[0.0, 0.5, 0.9, 0.0]]) + with pytest.raises(ValueError): + _match_spans(gt, pr, 0.5, "invalid") + + +# --------------------------------------------------------------------------- +# _f1_from_class_counts +# --------------------------------------------------------------------------- + +def test_f1_counts_micro_perfect(): + assert _f1_from_class_counts(np.array([2, 1]), np.array([0, 0]), np.array([0, 0]), "micro") == pytest.approx(1.0) + +def test_f1_counts_micro_all_zeros(): + assert _f1_from_class_counts(np.array([0, 0]), np.array([0, 0]), np.array([0, 0]), "micro") == pytest.approx(0.0) + +def test_f1_counts_micro_mixed(): + # tp=1, fp=1, fn=1 → P=0.5, R=0.5 → F1=0.5 + assert _f1_from_class_counts(np.array([1]), np.array([1]), np.array([1]), "micro") == pytest.approx(0.5) + +def test_f1_counts_macro_mixed(): + # class 0: tp=1, fp=0, fn=0 → F1=1.0; class 1: tp=0, fp=1, fn=1 → F1=0 → mean=0.5 + assert _f1_from_class_counts(np.array([1, 0]), np.array([0, 1]), np.array([0, 1]), "macro") == pytest.approx(0.5) + +def test_f1_counts_micro_macro_differ(): + # class 0: tp=1,fp=0,fn=0 → F1=1.0 | class 1: tp=0,fp=0,fn=1 → F1=0 + # micro: tp_s=1,fp_s=0,fn_s=1 → P=1,R=0.5 → F1=2/3 + # macro: mean([1.0, 0.0]) = 0.5 + tp = np.array([1, 0]) + fp = np.array([0, 0]) + fn = np.array([0, 1]) + assert _f1_from_class_counts(tp, fp, fn, "micro") == pytest.approx(2 / 3) + assert _f1_from_class_counts(tp, fp, fn, "macro") == pytest.approx(0.5) + +def test_f1_counts_invalid_mode_raises(): + with pytest.raises(ValueError): + _f1_from_class_counts(np.array([1]), np.array([0]), np.array([0]), "invalid") + + +# --------------------------------------------------------------------------- +# event_span_iou +# --------------------------------------------------------------------------- + +def test_event_span_iou_both_empty(): + s = _series([]) + assert event_span_iou([s], [s]) == pytest.approx(1.0) + +def test_event_span_iou_pred_empty(): + gt = _series([[0.0, 0.5, 1, 0]]) + pr = _series([]) + assert event_span_iou([gt], [pr]) == pytest.approx(0.0) + +def test_event_span_iou_gt_empty(): + gt = _series([]) + pr = _series([[0.0, 0.5, 0.9, 0.0]]) + assert event_span_iou([gt], [pr]) == pytest.approx(0.0) + +def test_event_span_iou_perfect_match(): + gt = _series([[0.0, 0.5, 1, 0]]) + pr = _series([[0.0, 0.5, 0.9, 0.0]]) + assert event_span_iou([gt], [pr]) == pytest.approx(1.0) + +def test_event_span_iou_no_overlap(): + gt = _series([[0.0, 0.3, 1, 0]]) + pr = _series([[0.7, 0.3, 0.9, 0.0]]) + assert event_span_iou([gt], [pr]) == pytest.approx(0.0) + +def test_event_span_iou_pred_below_score_threshold_ignored(): + # max class score 0.3 < 0.5 → pred filtered out → G=1, P=0 → F1=0 + gt = _series([[0.0, 0.5, 1, 0]]) + pr = _series([[0.0, 0.5, 0.3, 0.0]]) + assert event_span_iou([gt], [pr]) == pytest.approx(0.0) + +def test_event_span_iou_multi_series_averaged(): + # series 0: perfect → F1=1.0; series 1: no pred → F1=0.0; mean=0.5 + gt0, pr0 = _series([[0.0, 0.5, 1, 0]]), _series([[0.0, 0.5, 0.9, 0.0]]) + gt1, pr1 = _series([[0.2, 0.4, 1, 0]]), _series([]) + assert event_span_iou([gt0, gt1], [pr0, pr1]) == pytest.approx(0.5) + + +# --------------------------------------------------------------------------- +# event_iou_f1 +# --------------------------------------------------------------------------- + +def test_event_iou_f1_no_arrays_returns_nan(): + assert math.isnan(event_iou_f1([], [])) + +def test_event_iou_f1_perfect_match(): + gt = _series([[0.0, 0.5, 1, 0]]) + pr = _series([[0.0, 0.5, 0.9, 0.1]]) + assert event_iou_f1([gt], [pr]) == pytest.approx(1.0) + +def test_event_iou_f1_unmatched_gt_is_fn(): + gt = _series([[0.0, 0.5, 1, 0]]) + pr = _series([]) + assert event_iou_f1([gt], [pr]) == pytest.approx(0.0) + +def test_event_iou_f1_unmatched_pred_is_fp(): + gt = _series([]) + pr = _series([[0.0, 0.5, 0.9, 0.0]]) + assert event_iou_f1([gt], [pr]) == pytest.approx(0.0) + +def test_event_iou_f1_wrong_class_prediction(): + # Span matched by IoU but class assignment is swapped → class errors + # tp=[0,0], fp=[1,0], fn=[0,1] + # micro: tp_s=0, fp_s=1, fn_s=1 → F1=0 + gt = _series([[0.0, 0.5, 1, 0]]) # class 0 active + pr = _series([[0.0, 0.5, 0.4, 0.9]]) # predicts class 1 (max=0.9 > 0.5) + assert event_iou_f1([gt], [pr]) == pytest.approx(0.0) + +def test_event_iou_f1_micro_vs_macro_differ(): + # GT0=class0, GT1=class1; Pred0 correct, Pred1 wrong class + # After matching: (GT0,P0) tp=[1,0]; (GT1,P1) gt_cls=[0,1] pr_cls=[1,0] + # → tp+=[0,0], fp+=[1,0], fn+=[0,1] + # total: tp=[1,0], fp=[1,0], fn=[0,1] + # micro: tp_s=1,fp_s=1,fn_s=1 → P=0.5,R=0.5 → F1=0.5 + # macro: class0 P=0.5,R=1→F1=2/3; class1 P=0,R=0→F1=0 → mean=1/3 + gt = _series([[0.0, 0.3, 1, 0], [0.5, 0.3, 0, 1]]) + pr = _series([[0.0, 0.3, 0.9, 0.1], [0.5, 0.3, 0.6, 0.4]]) + assert event_iou_f1([gt], [pr], mode="micro") == pytest.approx(0.5) + assert event_iou_f1([gt], [pr], mode="macro") == pytest.approx(1 / 3) + +def test_event_iou_f1_hungarian_strategy(): + gt = _series([[0.0, 0.5, 1, 0]]) + pr = _series([[0.0, 0.5, 0.9, 0.1]]) + assert event_iou_f1([gt], [pr], matching_strategy="hungarian") == pytest.approx(1.0) + +def test_event_iou_f1_multi_series(): + gt0 = _series([[0.0, 0.5, 1, 0]]) + pr0 = _series([[0.0, 0.5, 0.9, 0.0]]) + gt1 = _series([[0.2, 0.4, 0, 1]]) + pr1 = _series([[0.2, 0.4, 0.1, 0.8]]) + assert event_iou_f1([gt0, gt1], [pr0, pr1]) == pytest.approx(1.0) + + # --------------------------------------------------------------------------- # f1_det # --------------------------------------------------------------------------- From a3864b691918c7c9e9c52407d7ca864f687218cd Mon Sep 17 00:00:00 2001 From: Tianjun Hou Date: Fri, 29 May 2026 15:40:44 +0200 Subject: [PATCH 11/15] feat: Hungarian matching loss + MITDB sliding-window windowing EventHead.compute_loss: - Replace naive slot-aligned loss with DETR-style Hungarian matching - Build (N, M) cost matrix per sample: L1 position cost + class cost - Solve optimal assignment via scipy.optimize.linear_sum_assignment - Matched slots: smooth_l1 position loss + GT class target - Unmatched slots: BCE toward no-event (all-zero) class target MITDB dataset: - Add window_size=512 and window_overlap=0.5 parameters - Add _extract_window_events helper: includes only beat events whose full span (start, length) = (R-peak - beat_window, 2*beat_window) lies entirely within the window; boundary-straddling events are discarded - get_data now emits one (512, 2) window per stride step instead of one element per full record segment - Positions normalised to [0, 1] over window_size Co-authored-by: Cursor --- benchmark_utils/adapters/event_detection.py | 111 ++++++++++++---- datasets/mitdb.py | 137 +++++++++++++++++--- 2 files changed, 202 insertions(+), 46 deletions(-) diff --git a/benchmark_utils/adapters/event_detection.py b/benchmark_utils/adapters/event_detection.py index 2b35d7e..3f6fd56 100644 --- a/benchmark_utils/adapters/event_detection.py +++ b/benchmark_utils/adapters/event_detection.py @@ -28,15 +28,24 @@ col 1 : length (normalised to [0,1] over T=512) col 2..2+k : binary multi-class one-hot columns (sum >= 1 for real events) -Loss ----- -For each of the 10 query slots: - * has_event mask = (y_cls.sum(-1) > 0) shape (B, N) - * position loss : smooth_l1 on (start, length), applied only where has_event - * class loss : BCEWithLogitsLoss on all k columns, applied to all slots - (empty slots drive logits toward 0, which is the no-event baseline) - -Combined: loss = pos_loss + lambda_cls * cls_loss (lambda_cls=1.0 by default) +Loss — DETR-style Hungarian matching +------------------------------------- +Training uses per-sample optimal bipartite assignment (Hungarian algorithm) +between the N predicted slots and the M real ground-truth events (M <= N): + + 1. Build cost matrix (N x M) per sample: + cost_pos[n,m] = L1(sigmoid(pos_logits[n]), gt_pos[m]) (sum over 2 dims) + cost_cls[n,m] = -sum_k sigmoid(cls_logits[n,k]) * gt_cls[m,k] + cost[n,m] = cost_pos[n,m] + lambda_cls * cost_cls[n,m] + 2. Solve: scipy.optimize.linear_sum_assignment(cost) -> (pred_idx, gt_idx) + 3. Position loss: smooth_l1 on matched (pred, gt) span pairs. + 4. Class loss: BCEWithLogitsLoss where + - matched slots -> their assigned GT class vector + - unmatched slots -> all-zero target (no-event) + 5. Average losses over batch. + +This gives the model permutation invariance: the N queries can predict events +in any order and the loss always finds the optimal pairing. """ import numpy as np @@ -159,36 +168,86 @@ def compute_loss( y: torch.Tensor, lambda_cls: float = 1.0, ) -> torch.Tensor: - """Compute combined position + classification loss. + """Compute combined position + classification loss with Hungarian matching. + + For each sample in the batch the optimal bipartite assignment between + the N predicted slots and the M real ground-truth events is found via + the Hungarian algorithm. Matched slots are trained on their assigned + GT event; unmatched slots are trained toward the no-event class target + (all-zero class vector). Parameters ---------- pos_logits : (B, N, 2) — raw span predictions (sigmoid applied inside) cls_logits : (B, N, k) — raw class logits - y : (B, N, 2+k) — ground-truth targets, float32 + y : (B, N, 2+k) — ground-truth targets, float32; + all-zero rows are empty / no-event slots lambda_cls : float — weight for the class loss term Returns ------- scalar loss tensor """ - y_pos = y[..., :2] # (B, N, 2) start, length - y_cls = y[..., 2:] # (B, N, k) binary class targets - - # Mask: only penalise position loss on slots that have a real event - has_event = (y_cls.sum(dim=-1) > 0).float() # (B, N) - - pos_pred = torch.sigmoid(pos_logits) # (B, N, 2) in [0,1] - pos_loss_per = nn.functional.smooth_l1_loss( - pos_pred, y_pos, reduction="none" - ).mean(dim=-1) # (B, N) - pos_loss = (pos_loss_per * has_event).sum() / (has_event.sum() + 1e-6) + from scipy.optimize import linear_sum_assignment + + B, N, _ = pos_logits.shape + y_pos = y[..., :2] # (B, N, 2) + y_cls = y[..., 2:] # (B, N, k) + + pos_pred = torch.sigmoid(pos_logits) # (B, N, 2) in [0,1] + cls_prob = torch.sigmoid(cls_logits) # (B, N, k) in [0,1] + + total_pos_loss = pos_logits.new_zeros(()) + total_cls_loss = pos_logits.new_zeros(()) + + for b in range(B): + # Identify real GT events for this sample + has_event_b = y_cls[b].sum(dim=-1) > 0 # (N,) bool mask over GT slots + gt_pos = y_pos[b][has_event_b] # (M, 2) + gt_cls = y_cls[b][has_event_b] # (M, k) + M = gt_pos.shape[0] + + pred_pos_b = pos_pred[b] # (N, 2) + cls_logits_b = cls_logits[b] # (N, k) + cls_prob_b = cls_prob[b] # (N, k) — used for cost matrix only + + if M == 0: + # No GT events: drive all slots to no-event (zero target) + total_cls_loss = total_cls_loss + nn.functional.binary_cross_entropy_with_logits( + cls_logits_b, + torch.zeros_like(cls_logits_b), + reduction="mean", + ) + continue + + # --- cost matrix (N x M) ---------------------------------------- + # L1 position cost: sum of |pred - gt| over the 2 span dimensions + with torch.no_grad(): + cost_pos = torch.cdist(pred_pos_b, gt_pos, p=1) # (N, M) + # Class cost: negative dot product of sigmoid probabilities and + # GT class vectors — lower cost means better class agreement. + cost_cls = -(cls_prob_b @ gt_cls.T) # (N, M) + cost = cost_pos + lambda_cls * cost_cls # (N, M) + + pred_idx, gt_idx = linear_sum_assignment(cost.cpu().numpy()) + + # --- position loss on matched pairs -------------------------------- + matched_pred_pos = pred_pos_b[pred_idx] # (M, 2) + matched_gt_pos = gt_pos[gt_idx] # (M, 2) + pos_loss_b = nn.functional.smooth_l1_loss( + matched_pred_pos, matched_gt_pos, reduction="mean" + ) + total_pos_loss = total_pos_loss + pos_loss_b - cls_loss = nn.functional.binary_cross_entropy_with_logits( - cls_logits, y_cls, reduction="mean" - ) + # --- class loss: matched → GT class, unmatched → zero target ------- + cls_target = torch.zeros_like(cls_logits_b) # (N, k) all zeros + cls_target[pred_idx] = gt_cls[gt_idx] # matched slots get real labels + cls_loss_b = nn.functional.binary_cross_entropy_with_logits( + cls_logits_b, cls_target, reduction="mean" + ) + total_cls_loss = total_cls_loss + cls_loss_b - return pos_loss + lambda_cls * cls_loss + return (total_pos_loss + lambda_cls * total_cls_loss) / B # --------------------------------------------------------------------------- diff --git a/datasets/mitdb.py b/datasets/mitdb.py index 1083b5b..d761755 100644 --- a/datasets/mitdb.py +++ b/datasets/mitdb.py @@ -9,19 +9,28 @@ class 2 V — Ventricular ectopic class 3 F — Fusion class 4 Q — Unknown / pacemaker artefact +Windowing +--------- +Each record is split into train/test portions and then sliced into fixed-size +overlapping windows (default: window_size=512 samples, overlap=50% → stride 256). +Only beat events whose full span [R-peak − beat_window, R-peak + beat_window) +lies **entirely within** the window are included. Events that straddle a +window boundary are discarded. Positions are normalised to [0, 1] over the +window length. + Data contract output -------------------- -X_train : List[np.ndarray (T_i, 2)] training portions (C == 2) -y_train : List[np.ndarray (N, 2+K)] one row per beat event, zero-padded - to N = max events across all segments: - col 0 start (normalised, 0–1) - col 1 width (normalised, 0–1) - cols 2… one-hot class vector (K) - all-zero row = empty/padding slot -X_test : List[np.ndarray (T_j, 2)] test portions -y_test : List[np.ndarray (N, 2+K)] same format, same N +X_train : List[np.ndarray (512, 2)] one window per element (C == 2) +y_train : List[np.ndarray (N, 2+K)] one row per beat event, zero-padded + to N = max events across all windows: + col 0 start (normalised, 0–1) + col 1 width (normalised, 0–1) + cols 2… one-hot class vector (K) + all-zero row = empty/padding slot +X_test : List[np.ndarray (512, 2)] test windows +y_test : List[np.ndarray (N, 2+K)] same format, same N task : "event_detection" -metrics : ["event_span_iou", "event_iou_f1"] +metrics : ["map_iou"] extra : n_classes (int) K above """ @@ -118,9 +127,76 @@ def _annotations_to_events(n_samples, ann_samples, ann_symbols, beat_window, return np.array(rows, dtype=np.float32) +def _extract_window_events(ann_samples, ann_symbols, w_start, window_size, + beat_window, n_classes): + """Extract events that are fully contained within a single window. + + Each event is represented as ``(start, length)`` normalised to [0, 1] + over ``window_size``. An event centred at R-peak position ``s`` + (in segment coordinates) has: + + event_start = s - beat_window + event_length = 2 * beat_window (constant for all beats) + + The event is included **only if** its full span lies within the window: + + event_start >= w_start AND + event_start + event_length <= w_start + window_size + + Events that straddle a window boundary are discarded entirely. + + Parameters + ---------- + ann_samples : np.ndarray (A,) int annotation positions in segment coords + ann_symbols : list of str annotation symbols (len A) + w_start : int window start position in segment coords + window_size : int window length in samples (e.g. 512) + beat_window : int half-width of each event box in samples + n_classes : int K — number of AAMI classes + + Returns + ------- + events : np.ndarray (E, 2+K) float32 + Each row: [start_norm, length_norm, *one_hot_class] where + start_norm = (event_start - w_start) / window_size in [0, 1] + length_norm = event_length / window_size in [0, 1] + E=0 if no events are fully contained in the window. + """ + event_length = 2 * beat_window # constant: full span of every beat box + rows = [] + for sample, symbol in zip(ann_samples, ann_symbols): + aami_class = BEAT_CLASS.get(symbol) + if aami_class is None: + continue + class_idx = 0 if n_classes == 1 else aami_class - 1 + if class_idx >= n_classes: + continue + + event_start = sample - beat_window + + # Discard events whose span crosses a window boundary + if event_start < w_start or event_start + event_length > w_start + window_size: + continue + + one_hot = np.zeros(n_classes, dtype=np.float32) + one_hot[class_idx] = 1.0 + start_norm = (event_start - w_start) / window_size + length_norm = event_length / window_size + rows.append([start_norm, length_norm, *one_hot]) + + if not rows: + return np.zeros((0, 2 + n_classes), dtype=np.float32) + return np.array(rows, dtype=np.float32) + + class Dataset(BaseDataset): """MIT-BIH Arrhythmia Database for 1-D event detection. + Each record is split chronologically into train/test portions, then each + portion is sliced into fixed-size overlapping windows. Only beat events + that fit entirely within a window are kept; boundary-straddling events are + discarded. Positions are normalised to [0, 1] over ``window_size``. + Parameters ---------- record_ids : list of str or "all" @@ -136,6 +212,11 @@ class Dataset(BaseDataset): K — number of AAMI beat classes to distinguish (1–5). Classes are ordered N, S, V, F, Q; setting n_classes=1 collapses all annotated beats into a single "beat" class. + window_size : int + Length of each window in samples (default 512). + window_overlap : float + Fractional overlap between consecutive windows in [0, 1). + Default 0.5 → stride = window_size // 2 = 256 samples. """ name = "MITDB" @@ -148,6 +229,8 @@ class Dataset(BaseDataset): "train_ratio": [0.7], "beat_window": [36], "n_classes": [5], + "window_size": [512], + "window_overlap": [0.5], } def get_data(self): @@ -157,6 +240,8 @@ def get_data(self): if self.debug: record_ids = record_ids[:2] + stride = max(1, int(self.window_size * (1.0 - self.window_overlap))) + X_train, y_train, X_test, y_test = [], [], [], [] for rid in record_ids: signal, ann_samples, ann_symbols = _load_record(rid, data_dir) @@ -169,21 +254,33 @@ def get_data(self): split = max(1, int(len(signal) * self.train_ratio)) - for seg_signal, start, end, Xl, yl in [ + for seg_signal, seg_start, seg_end, Xl, yl in [ (signal[:split], 0, split, X_train, y_train), (signal[split:], split, len(signal), X_test, y_test), ]: - seg_ann = ann_samples[(ann_samples >= start) & (ann_samples < end)] - start + # Annotations relative to the segment (0-based within seg_signal) + seg_ann = (ann_samples[(ann_samples >= seg_start) & (ann_samples < seg_end)] + - seg_start) seg_sym = [s for s, idx in zip(ann_symbols, ann_samples) - if start <= idx < end] - Xl.append(seg_signal) - yl.append(_annotations_to_events( - len(seg_signal), seg_ann, seg_sym, - self.beat_window, self.n_classes, - )) + if seg_start <= idx < seg_end] + + seg_len = len(seg_signal) + if seg_len < self.window_size: + # Segment too short for even one window — skip + continue + + for w_start in range(0, seg_len - self.window_size + 1, stride): + window = seg_signal[w_start: w_start + self.window_size] + events = _extract_window_events( + seg_ann, seg_sym, + w_start, self.window_size, + self.beat_window, self.n_classes, + ) + Xl.append(window) + yl.append(events) # Pad all event arrays to a uniform N so solvers can np.stack them. - # N = max events in any single segment across train and test. + # N = max events in any single window across train and test. all_y = y_train + y_test max_n = max((y.shape[0] for y in all_y), default=0) n_cols = 2 + self.n_classes @@ -204,6 +301,6 @@ def _pad(arrays): X_test=X_test, y_test=_pad(y_test), task="event_detection", - metrics=["event_span_iou", "event_iou_f1"], + metrics=["map_iou"], n_classes=self.n_classes, ) From 8160f552447edbda9a544c538b5785f5151c299b Mon Sep 17 00:00:00 2001 From: Tianjun Hou Date: Fri, 29 May 2026 15:51:13 +0200 Subject: [PATCH 12/15] fix: fall back to ChronosPipeline for Chronos-1/T5 local checkpoints Chronos2Pipeline.from_pretrained raises AttributeError when the model config references T5ForConditionalGeneration (Chronos-1 architecture). load_model now catches AttributeError/ValueError and retries with the legacy ChronosPipeline, so local chronos_t5_* checkpoints work without any solver parameter changes. Co-authored-by: Cursor --- solvers/chronos.py | 305 ++++++++++++++++++++++++--------------------- 1 file changed, 165 insertions(+), 140 deletions(-) diff --git a/solvers/chronos.py b/solvers/chronos.py index 0b438bc..e40e872 100644 --- a/solvers/chronos.py +++ b/solvers/chronos.py @@ -7,7 +7,7 @@ - event_detection : EventHead trained on frozen encoder embeddings Model loading is done in ``set_objective`` (untimed). Inference batches -every (series, cutoff) pair into a single ``ChronosPipeline.predict`` +every (series, cutoff) pair into a single ``Chronos2Pipeline.predict`` call — the pipeline accepts a list of variable-length tensors and applies left-padding internally, so all the per-cutoff work happens in one forward pass. @@ -15,7 +15,6 @@ import numpy as np import torch -from benchopt import BaseSolver from chronos import ChronosPipeline from benchmark_utils.adapters import ( @@ -33,6 +32,7 @@ precompute_embeddings, ) from benchmark_utils.adapters.forecast_residual import ForecastResidualAdapter +from benchmark_utils.base_solver import BaseTSFMSolver from benchmark_utils.inputs import ForecastInput from benchmark_utils.outputs import ForecastOutput @@ -48,94 +48,31 @@ } -class _ChronosForecaster(BaseTSFMAdapter): - """Batched Chronos v1 adapter; quantiles are derived from sample draws.""" - - DEFAULT_QUANTILE_LEVELS = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9) - - def __init__(self, pipeline, prediction_length, quantile_levels=None): - self.pipeline = pipeline - self.prediction_length = prediction_length - self.quantile_levels = quantile_levels or self.DEFAULT_QUANTILE_LEVELS - - # ------------------------------------------------------------------ - # Template method — subclasses override _build_inputs / _assemble - # ------------------------------------------------------------------ - - def predict(self, x: ForecastInput) -> ForecastOutput: - inputs, layout, per_series_shape = self._build_inputs(x) - if not inputs: - return ForecastOutput(quantiles=[], quantile_levels=self.quantile_levels) - - with torch.no_grad(): - output = self.pipeline.predict( - inputs, - prediction_length=self.prediction_length, - ) - return self._assemble_output(output, layout, per_series_shape) - - def _build_inputs(self, x): - """Build list of 1-D tensors (one per channel) and track layout.""" - inputs = [] - layout = [] # (series_idx, cutoff_idx, channel_idx) - per_series_shape = [] # (C, n_cutoffs) - for series_idx, (series, cutoffs) in enumerate(zip(x.x, x.cutoff_indexes)): - series = np.asarray(series, dtype=np.float32) - if series.ndim == 1: - series = series[:, None] - _, C = series.shape - per_series_shape.append((C, len(cutoffs))) - for cutoff_idx, cutoff in enumerate(cutoffs): - hist = series[:cutoff] - for c in range(C): - inputs.append(torch.from_numpy(hist[:, c])) - layout.append((series_idx, cutoff_idx, c)) - return inputs, layout, per_series_shape - - def _assemble_output(self, samples, layout, per_series_shape): - """Derive quantile fan from Monte-Carlo sample draws.""" - # samples: (n_inputs, num_samples, H) - q_arr = np.quantile( - samples.float().cpu().numpy(), - q=list(self.quantile_levels), - axis=1, - ).transpose(1, 0, 2) # (n_inputs, Q, H) - - Q = len(self.quantile_levels) - per_series = [ - np.empty((n_cutoffs, Q, self.prediction_length, C), dtype=np.float32) - for C, n_cutoffs in per_series_shape - ] - for i, (series_idx, cutoff_idx, c) in enumerate(layout): - per_series[series_idx][cutoff_idx, :, :, c] = q_arr[i] - - return ForecastOutput(quantiles=per_series, quantile_levels=self.quantile_levels) +def _to_context(x): + """Reshape ``(T, V)`` or ``(B, T, V)`` to Chronos input ``(B, V, T)``.""" + x = np.asarray(x, dtype=np.float32) + if x.ndim == 2: + x = x[None] + return x.transpose(0, 2, 1) class _ChronosEmbedEncoder(UnpooledEncoder): - """Default path — uses ``ChronosPipeline.embed``. + """Default path — uses ``Chronos2Pipeline.embed``. - Returns hidden states *after* ``encoder.final_layer_norm``. + Returns hidden states *after* ``encoder.final_layer_norm`` for each + series in the batch. """ def __init__(self, pipeline): self.pipeline = pipeline def encode(self, X) -> np.ndarray: - # X: (B, T, V) or (T, V). - X = np.asarray(X, dtype=np.float32) - batched = X.ndim == 3 - if not batched: - X = X[None] # (1, T, V) - B, T, V = X.shape - - # Chronos is univariate — flatten B & V into the batch axis. - flat = X.reshape(B * V, T) # (B*V, T) + context = _to_context(X) # (B, V, T) with torch.no_grad(): - emb, _ = self.pipeline.embed(torch.from_numpy(flat)) # (B*V, T_tok, D) - - # (B*V, T_tok, D) -> (B, T_tok, V, D) - return emb.float().cpu().numpy().reshape(B, -1, V, emb.shape[-1]) + # embed returns a list of B tensors, each of shape (V, T, D). + embeddings, _ = self.pipeline.embed(context) + stacked = torch.stack(list(embeddings)) # (B, V, T, D) + return stacked.transpose(1, 2).float().cpu().numpy() # (B, T, V, D) class _ChronosHookEncoder(UnpooledEncoder): @@ -154,15 +91,8 @@ def __init__(self, pipeline, layer: int): ) self._block_idx = layer % n_blocks - def encode(self, X) -> np.ndarray: - X = np.asarray(X, dtype=np.float32) - batched = X.ndim == 3 - if not batched: - X = X[None] # (1, T, V) - B, T, V = X.shape - - flat = X.reshape(B * V, T) # (B*V, T) - context = torch.from_numpy(flat) + def encode(self, x: np.ndarray) -> np.ndarray: + context = _to_context(x) # (B, V, T) token_ids, attn_mask, _ = self.pipeline.tokenizer.context_input_transform( context ) @@ -174,6 +104,8 @@ def encode(self, X) -> np.ndarray: captured = {} def _hook(_module, _inputs, output): + # Hook to capture the embeddings while performing a forward pass + # T5Block returns a tuple; first element is the hidden state. hidden = output[0] if isinstance(output, tuple) else output captured["h"] = hidden.detach() @@ -184,8 +116,8 @@ def _hook(_module, _inputs, output): finally: handle.remove() - # (B*V, T_tok, D) -> (B, T_tok, V, D) - return captured["h"].float().cpu().numpy().reshape(B, -1, V, captured["h"].shape[-1]) + # (C, T_tok, D) -> (T_tok, C, D) + return captured["h"].transpose(0, 1).float().cpu().numpy() def ChronosEncoder(pipeline, layer: int | None = None) -> UnpooledEncoder: @@ -224,8 +156,8 @@ def ChronosEncoder(pipeline, layer: int | None = None) -> UnpooledEncoder: # --------------------------------------------------------------------------- -class Solver(BaseSolver): - """Chronos zero-shot solver. +class Solver(BaseTSFMSolver): + """Chronos-2 zero-shot solver. Parameters ---------- @@ -243,9 +175,7 @@ class Solver(BaseSolver): name = "Chronos" - requirements = ["pip::chronos-forecasting>=2.2", "pip::torch"] - - sampling_strategy = "run_once" + requirements = ["pip::chronos-forecasting>=2.2,<3"] parameters = { "model_size": ["small"], @@ -263,43 +193,134 @@ class Solver(BaseSolver): "num_queries": [10], } - def skip(self, task, **kwargs): - if task not in SUPPORTED_TASKS: - return True, f"Chronos solver does not support task={task!r}" - return False, None - - def set_objective(self, X_train, y_train, task, **meta): - self.task = task - self.X_train = X_train - self.y_train = y_train - self.meta = meta - - # bfloat16 is fine on CUDA but poorly supported on CPU / MPS; - # fall back to float32 there so inference doesn't crash or stall. - device = "cuda" if torch.cuda.is_available() else "cpu" - dtype = torch.bfloat16 if device == "cuda" else torch.float32 - - model_id = self.model_path if self.model_path else f"amazon/chronos-t5-{self.model_size}" + def __init__( + self, + model_size="small", + layer=None, + pooler="mean", + model_path="", + batch_size=32, + num_epochs=100, + lr=3e-4, + weight_decay=1e-4, + warmup_epochs=5, + num_dec_layers=2, + lambda_cls=1.0, + num_queries=10, + ): + """Initialize Chronos-specific state. + + Parameters + ---------- + model_size : str, default="small" + Chronos model variant to load. + layer : int or None, default=None + Encoder block index for classification embeddings. + pooler : {"mean", "max", "last"}, default="mean" + Pooling strategy over the time-token axis for classification. + model_path : str, default="" + Local model directory; empty = load from HuggingFace Hub. + """ + super().__init__( + model_size=model_size, + layer=layer, + pooler=pooler, + model_path=model_path, + batch_size=batch_size, + num_epochs=num_epochs, + lr=lr, + weight_decay=weight_decay, + warmup_epochs=warmup_epochs, + num_dec_layers=num_dec_layers, + lambda_cls=lambda_cls, + num_queries=num_queries, + ) + self._pipeline = None + self._loaded_model = None + + @property + def supported_tasks(self): + return SUPPORTED_TASKS + + def load_model(self, device, dtype): + """Load Chronos pipeline (cached if already loaded). + + Chronos-2 models (autogluon/chronos-2-*) are loaded via + ``Chronos2Pipeline``; Chronos-1 / T5-based models (e.g. a local + ``chronos_t5_*`` checkpoint) are loaded via ``ChronosPipeline``. + Detection is done by inspecting the ``architectures`` field of the + model config; if ``Chronos2Pipeline`` raises ``AttributeError`` the + loader falls back to ``ChronosPipeline`` automatically. + """ + from chronos import ChronosPipeline, Chronos2Pipeline + + model_id = f"autogluon/chronos-2-{self.model_size}" + model_id = self.model_path if self.model_path else model_id if not hasattr(self, "_pipeline") or self._loaded_model != model_id: - self._pipeline = ChronosPipeline.from_pretrained( - model_id, - device_map=device, - dtype=dtype, - ) + try: + self._pipeline = Chronos2Pipeline.from_pretrained( + model_id, + device_map=device, + dtype=dtype, + ) + except (AttributeError, ValueError): + # Chronos-1 / T5-based checkpoint — fall back to ChronosPipeline + self._pipeline = ChronosPipeline.from_pretrained( + model_id, + device_map=device, + dtype=dtype, + ) self._loaded_model = model_id + return self._pipeline + + def set_objective(self, X_train, y_train, task, **meta): + """Load pipeline then pre-compute embeddings for event_detection.""" + super().set_objective(X_train, y_train, task, **meta) if task == "event_detection": - self._Z_train = precompute_embeddings(self._pipeline, X_train) - self._d_model = self._Z_train[0].shape[-1] - self._n_classes = np.asarray(y_train[0]).shape[-1] - 2 + self._n_classes = int(meta["n_classes"]) + self._d_model = _CHRONOS_D.get(self.model_size, 512) + self._Z_train = precompute_embeddings(self.model, X_train) + + def forecast_batch(self, inputs): + """Chronos-specific batch prediction. + + Parameters + ---------- + inputs : list of torch.Tensor + Each tensor shape (C, T_cutoff) + + Returns + ------- + list of torch.Tensor + Each tensor shape (C, Q, H) + """ + with torch.no_grad(): + return self.model.predict(inputs, prediction_length=self.prediction_length) - def run(self, _): + def build_adapter(self, task, model): + # TODO later: put that code in base_solver.py + # and make it rely on .forecast(), .embed() and .time_embed() only, once those are all properly coded + """Create task-specific adapter for Chronos.""" pred_len = self.meta.get("prediction_length", 1) - if self.task == "forecasting": - self._adapter = _ChronosForecaster(self._pipeline, pred_len) - elif self.task == "classification": - base_encoder = ChronosEncoder(self._pipeline, layer=self.layer) + if task == "forecasting": + self.prediction_length = pred_len + quantile_levels = tuple(float(q) for q in model.quantiles) + + # Create a simple adapter that calls self.forecast() + class _ForecastAdapter(BaseTSFMAdapter): + def __init__(self, solver, quantile_levels): + self.solver = solver + self.quantile_levels = quantile_levels + + def predict(self, x: ForecastInput) -> ForecastOutput: + return self.solver.forecast(x, self.solver.prediction_length, self.quantile_levels) + + return _ForecastAdapter(self, quantile_levels) + + elif task == "classification": + base_encoder = ChronosEncoder(model, layer=self.layer) encoder = Encoder(base_encoder, POOLERS[self.pooler]()) adapter = LinearProbeAdapter( encoder, @@ -307,16 +328,26 @@ def run(self, _): n_classes=self.meta.get("n_classes"), ) adapter.fit(self.X_train, self.y_train) - self._adapter = adapter + return adapter - elif self.task == "anomaly_detection": + elif task == "anomaly_detection": # AD uses one-step-ahead forecasts. - self._adapter = ForecastResidualAdapter( - _ChronosForecaster(self._pipeline, prediction_length=1), - prediction_length=1, - ) + self.prediction_length = 1 + quantile_levels = (0.5,) + + # Create a forecaster adapter for residual-based anomaly detection + class _ForecasterForAD(BaseTSFMAdapter): + def __init__(self, solver, quantile_levels): + self.solver = solver + self.quantile_levels = quantile_levels + + def predict(self, x: ForecastInput) -> ForecastOutput: + return self.solver.forecast(x, 1, self.quantile_levels) - elif self.task == "event_detection": + forecaster = _ForecasterForAD(self, quantile_levels) + return ForecastResidualAdapter(forecaster, prediction_length=1) + + elif task == "event_detection": device = "cuda" if torch.cuda.is_available() else "cpu" head = fit_event_head( self._Z_train, self.y_train, self._n_classes, self._d_model, @@ -324,12 +355,6 @@ def run(self, _): self.weight_decay, self.warmup_epochs, self.num_dec_layers, self.lambda_cls, self.num_queries, ) - self._adapter = ChronosEventAdapter( - self._pipeline, head, device, self._n_classes - ) - - else: - raise ValueError(f"Unknown task: {self.task}") + return ChronosEventAdapter(model, head, device, self._n_classes) - def get_result(self): - return {"model": self._adapter} + raise ValueError(f"Unknown task: {task}") From 9c08ebb742d0d1c20665501da98fe72619d073f1 Mon Sep 17 00:00:00 2001 From: ABarneche Date: Fri, 29 May 2026 16:22:08 +0200 Subject: [PATCH 13/15] adding a batch size for embeddings --- benchmark_utils/adapters/event_detection.py | 48 +++++++++++++++------ datasets/mitdb.py | 4 +- solvers/chronos.py | 2 +- 3 files changed, 37 insertions(+), 17 deletions(-) diff --git a/benchmark_utils/adapters/event_detection.py b/benchmark_utils/adapters/event_detection.py index 3f6fd56..6435de2 100644 --- a/benchmark_utils/adapters/event_detection.py +++ b/benchmark_utils/adapters/event_detection.py @@ -241,7 +241,7 @@ def compute_loss( # --- class loss: matched → GT class, unmatched → zero target ------- cls_target = torch.zeros_like(cls_logits_b) # (N, k) all zeros - cls_target[pred_idx] = gt_cls[gt_idx] # matched slots get real labels + cls_target[pred_idx] = gt_cls[gt_idx].to(cls_target.dtype) cls_loss_b = nn.functional.binary_cross_entropy_with_logits( cls_logits_b, cls_target, reduction="mean" ) @@ -354,7 +354,7 @@ def lr_lambda(epoch): return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) -def precompute_embeddings(pipeline, X_train): +def precompute_embeddings(pipeline, X_train, batch_size=128): """Embed every training series via a frozen Chronos pipeline. Each series is embedded channel-by-channel and the per-channel @@ -367,23 +367,43 @@ def precompute_embeddings(pipeline, X_train): A loaded (and frozen) Chronos pipeline exposing ``.embed()``. X_train : List[np.ndarray (T, C)] Training time series. + batch_size : int + Number of series to embed in a single ``pipeline.embed()`` call. + Each channel is treated as a separate univariate series, so the + actual call size is ``batch_size * C``. Returns ------- List[torch.Tensor] — one (T_tok, D) float32 CPU tensor per series. """ Z_train = [] + n = len(X_train) with torch.no_grad(): - for x in X_train: - x = np.asarray(x, dtype=np.float32) - C = x.shape[1] - channel_embs = [] - for c in range(C): - ctx = torch.tensor(x[:, c], dtype=torch.float32) - emb, _ = pipeline.embed(ctx.unsqueeze(0)) # (1, T_tok, D) - channel_embs.append(emb.squeeze(0).float().cpu()) - stacked = torch.stack(channel_embs, dim=0) # (C, T_tok, D) - Z_train.append(stacked.mean(dim=0)) # (T_tok, D) + for start in range(0, n, batch_size): + batch = X_train[start:start + batch_size] + # Flatten all channels across the batch into one list of 1D tensors. + # Order: x0c0, x0c1, ..., x1c0, x1c1, ... + contexts = [] + channel_counts = [] + for x in batch: + x = np.asarray(x, dtype=np.float32) + C = x.shape[1] + channel_counts.append(C) + for c in range(C): + contexts.append(torch.tensor(x[:, c], dtype=torch.float32)) + + embs, _ = pipeline.embed(contexts) # (sum(C_i), T_tok, D) + embs = embs.float().cpu() + + idx = 0 + for C in channel_counts: + channel_embs = embs[idx:idx + C] # (C, T_tok, D) + Z_train.append(channel_embs.mean(dim=0)) # (T_tok, D) + idx += C + + print(f" Embedded {min(start + batch_size, n)}/{n} series", end="\r") + + print() return Z_train @@ -478,8 +498,8 @@ def fit_event_head( scheduler = _get_linear_cosine_scheduler(optimizer, warmup_epochs, num_epochs) N_train = len(Z_train) - use_amp = device == "cuda" - scaler = torch.amp.GradScaler("cuda", enabled=use_amp) + use_amp = "cuda" in device + scaler = torch.amp.GradScaler(device, enabled=use_amp) for epoch in range(num_epochs): indices = np.random.permutation(N_train) diff --git a/datasets/mitdb.py b/datasets/mitdb.py index d761755..4d392cd 100644 --- a/datasets/mitdb.py +++ b/datasets/mitdb.py @@ -30,7 +30,7 @@ class 4 Q — Unknown / pacemaker artefact X_test : List[np.ndarray (512, 2)] test windows y_test : List[np.ndarray (N, 2+K)] same format, same N task : "event_detection" -metrics : ["map_iou"] +metrics : ["map_iou", "event_iou_f1"] extra : n_classes (int) K above """ @@ -301,6 +301,6 @@ def _pad(arrays): X_test=X_test, y_test=_pad(y_test), task="event_detection", - metrics=["map_iou"], + metrics=["map_iou", "event_iou_f1"], n_classes=self.n_classes, ) diff --git a/solvers/chronos.py b/solvers/chronos.py index e40e872..768fcb6 100644 --- a/solvers/chronos.py +++ b/solvers/chronos.py @@ -280,7 +280,7 @@ def set_objective(self, X_train, y_train, task, **meta): if task == "event_detection": self._n_classes = int(meta["n_classes"]) self._d_model = _CHRONOS_D.get(self.model_size, 512) - self._Z_train = precompute_embeddings(self.model, X_train) + self._Z_train = precompute_embeddings(self.model, X_train, batch_size=self.batch_size) def forecast_batch(self, inputs): """Chronos-specific batch prediction. From eb55538ab62b280e9d6f051f6e3a7859304b3d91 Mon Sep 17 00:00:00 2001 From: ABarneche Date: Fri, 29 May 2026 16:49:27 +0200 Subject: [PATCH 14/15] more prints+debug --- benchmark_utils/adapters/event_detection.py | 20 ++++++++++++++------ solvers/chronos.py | 16 +++++++++++++--- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/benchmark_utils/adapters/event_detection.py b/benchmark_utils/adapters/event_detection.py index 6435de2..e9ccbd0 100644 --- a/benchmark_utils/adapters/event_detection.py +++ b/benchmark_utils/adapters/event_detection.py @@ -500,6 +500,7 @@ def fit_event_head( N_train = len(Z_train) use_amp = "cuda" in device scaler = torch.amp.GradScaler(device, enabled=use_amp) + num_batches_per_epoch = max(1, int(np.ceil(N_train / batch_size))) for epoch in range(num_epochs): indices = np.random.permutation(N_train) @@ -541,15 +542,22 @@ def fit_event_head( epoch_loss += loss.item() num_batches += 1 - scheduler.step() - - if (epoch + 1) % 10 == 0 or epoch == 0: - avg = epoch_loss / max(num_batches, 1) - lr_now = scheduler.get_last_lr()[0] print( f" Epoch {epoch + 1:3d}/{num_epochs} | " - f"loss={avg:.4f} | lr={lr_now:.2e}" + f"step {num_batches:4d}/{num_batches_per_epoch} | " + f"loss={loss.item():.4f}", + end="\r", ) + scheduler.step() + + avg = epoch_loss / num_batches + lr_now = scheduler.get_last_lr()[0] + print( + f" Epoch {epoch + 1:3d}/{num_epochs} | " + f"loss={avg:.4f} | lr={lr_now:.2e}" + + " " * 20 # clear leftover step line + ) + head.eval() return head diff --git a/solvers/chronos.py b/solvers/chronos.py index 4a002ee..dea9442 100644 --- a/solvers/chronos.py +++ b/solvers/chronos.py @@ -258,11 +258,11 @@ class Solver(BaseTSFMSolver): "pooler": ["mean"], # event_detection — single values so no cross-product for other tasks "model_path": [""], - "batch_size": [32], - "num_epochs": [100], + "batch_size": [64], + "num_epochs": [3], "lr": [3e-4], "weight_decay": [1e-4], - "warmup_epochs": [5], + "warmup_epochs": [1], "num_dec_layers": [2], "lambda_cls": [1.0], "num_queries": [10], @@ -287,6 +287,11 @@ def __init__( num_dec_layers=2, lambda_cls=1.0, num_queries=10, + classifier="log_reg", + penalty="l2", + C=1.0, + alpha=1.0, + n_iterators=100, ): """Initialize Chronos-specific state. @@ -314,6 +319,11 @@ def __init__( num_dec_layers=num_dec_layers, lambda_cls=lambda_cls, num_queries=num_queries, + classifier=classifier, + penalty=penalty, + C=C, + alpha=alpha, + n_iterators=n_iterators, ) self._pipeline = None self._loaded_model = None From f0a10caaba5dcc440705a482dfd2096f290f97b2 Mon Sep 17 00:00:00 2001 From: ABarneche Date: Fri, 29 May 2026 17:02:36 +0200 Subject: [PATCH 15/15] fix moment head --- solvers/moment.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/solvers/moment.py b/solvers/moment.py index 65eb5e1..2db0380 100644 --- a/solvers/moment.py +++ b/solvers/moment.py @@ -62,7 +62,7 @@ def predict(self, x: ForecastInput) -> ForecastOutput: preds_per_series = [] for cutoff in cutoffs: hist = series[:cutoff] # (T_cutoff, C) - + if hist.ndim == 1: hist = hist[None, :] @@ -105,12 +105,12 @@ def predict(self, x: ForecastInput) -> ForecastOutput: ) preds_per_series.append(arr) - + # Stack predictions: (n_cutoffs, prediction_length, C) stacked = np.stack(preds_per_series, axis=0) # Add quantile dimension: (n_cutoffs, 1, prediction_length, C) quantiles.append(stacked[:, None, :, :]) - + return ForecastOutput(quantiles=quantiles, quantile_levels=(0.5,)) @@ -122,7 +122,7 @@ def __init__(self, pipeline): def encode(self, X) -> np.ndarray: """Extract embeddings from time series data. - + Args: X: np.ndarray of shape (T, C) or (B, T, C) @@ -174,7 +174,6 @@ class Solver(BaseSolver): "pip::moment @ git+https://github.com/moment-timeseries-foundation-model/moment.git", ] - sampling_strategy = "run_once" parameters = { @@ -182,17 +181,11 @@ class Solver(BaseSolver): "task_config": ["forecasting"], # forecasting or classification "pooler": ["mean"], # pooler for classification embeddings "batch_size": [32], -<<<<<<< HEAD - "classifier": ["logistic_regression"], - "max_iter": [1000], - "n_estimators": [100], -======= "classifier": ["log_reg"], "penalty": ["l2"], "C": [1.0], "alpha": [1.0], "n_iterators": [100], ->>>>>>> c573cd85c2b80ced694b62bda534833d5461cee5 } def skip(self, task, **kwargs):