diff --git a/benchmark_utils/adapters/__init__.py b/benchmark_utils/adapters/__init__.py index c5d9f32..6039249 100644 --- a/benchmark_utils/adapters/__init__.py +++ b/benchmark_utils/adapters/__init__.py @@ -9,6 +9,7 @@ ) from .linear_probe import LinearProbeAdapter from .forecast_residual import ForecastResidualAdapter +from .event_detection import EventHead, ChronosEventAdapter __all__ = [ "BaseTSFMAdapter", @@ -20,4 +21,6 @@ "Encoder", "LinearProbeAdapter", "ForecastResidualAdapter", + "EventHead", + "ChronosEventAdapter", ] diff --git a/benchmark_utils/adapters/event_detection.py b/benchmark_utils/adapters/event_detection.py new file mode 100644 index 0000000..e9ccbd0 --- /dev/null +++ b/benchmark_utils/adapters/event_detection.py @@ -0,0 +1,563 @@ +"""Event-detection adapter for a frozen Chronos encoder + trainable head. + +Architecture +------------ +x (T, C) + -> Chronos.embed per channel (T_tok, D) x C + -> mean-pool across channels (T_tok, D) [memory key/value] + -> EventHead Transformer-decoder [10 learned queries] + -> pos_head : Linear(D,2) -> sigmoid (start, length) in [0,1] + -> cls_head : Linear(D,k) -> logits k binary class logits + +Output shape per series: (N=10, 2+k) + +Channel handling +---------------- +Chronos is strictly univariate. We embed every channel independently and then +**mean-pool the per-channel (T_tok, D) tensors**. Alternatives are possible: + + * concat along D: increases D by C-fold, requires retraining head per C + * attention pool: add a learned aggregation layer (more parameters) + +Mean-pool is simple, parameter-free, and works for any C at inference time. + +Target format +------------- +y per series: (N=10, 2+k) float32, all-zero rows = empty / no-event slots. 
+ col 0 : start (normalised to [0,1] over T=512) + col 1 : length (normalised to [0,1] over T=512) + col 2..2+k : binary multi-class one-hot columns (sum >= 1 for real events) + +Loss — DETR-style Hungarian matching +------------------------------------- +Training uses per-sample optimal bipartite assignment (Hungarian algorithm) +between the N predicted slots and the M real ground-truth events (M <= N): + + 1. Build cost matrix (N x M) per sample: + cost_pos[n,m] = L1(sigmoid(pos_logits[n]), gt_pos[m]) (sum over 2 dims) + cost_cls[n,m] = -sum_k sigmoid(cls_logits[n,k]) * gt_cls[m,k] + cost[n,m] = cost_pos[n,m] + lambda_cls * cost_cls[n,m] + 2. Solve: scipy.optimize.linear_sum_assignment(cost) -> (pred_idx, gt_idx) + 3. Position loss: smooth_l1 on matched (pred, gt) span pairs. + 4. Class loss: BCEWithLogitsLoss where + - matched slots -> their assigned GT class vector + - unmatched slots -> all-zero target (no-event) + 5. Average losses over batch. + +This gives the model permutation invariance: the N queries can predict events +in any order and the loss always finds the optimal pairing. +""" + +import numpy as np +import torch +import torch.nn as nn + +from .base import BaseTSFMAdapter + + +# --------------------------------------------------------------------------- +# EventHead +# --------------------------------------------------------------------------- + +class EventHead(nn.Module): + """Transformer-decoder that turns Chronos embeddings into span predictions. + + Parameters + ---------- + d_model : int + Embedding dimension coming out of the Chronos encoder. + n_classes : int + Number of binary event-class columns k. + num_queries : int + Fixed number of event slots N (default 10). + num_decoder_layers : int + Depth of the Transformer decoder (default 2). + nhead : int + Number of attention heads (default 8 when d_model >= 512). + dim_feedforward : int + FFN inner dimension (default 4 * d_model). + dropout : float + Dropout rate in the decoder (default 0.1). 
+ """ + + def __init__( + self, + d_model: int, + n_classes: int, + num_queries: int = 10, + num_decoder_layers: int = 2, + nhead: int = 8, + dim_feedforward: int | None = None, + dropout: float = 0.1, + ): + super().__init__() + self.d_model = d_model + self.n_classes = n_classes + self.num_queries = num_queries + + if dim_feedforward is None: + dim_feedforward = 4 * d_model + + # Ensure nhead divides d_model cleanly + while d_model % nhead != 0 and nhead > 1: + nhead //= 2 + + decoder_layer = nn.TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + batch_first=True, + ) + self.decoder = nn.TransformerDecoder( + decoder_layer, + num_layers=num_decoder_layers, + ) + + # N learnable query embeddings — one per event slot + self.query_embed = nn.Embedding(num_queries, d_model) + + # Output heads + self.pos_head = nn.Sequential( + nn.Linear(d_model, d_model // 2), + nn.ReLU(), + nn.Linear(d_model // 2, 2), + ) + self.cls_head = nn.Sequential( + nn.Linear(d_model, d_model // 2), + nn.ReLU(), + nn.Linear(d_model // 2, n_classes), + ) + + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + nn.init.normal_(self.query_embed.weight, std=0.02) + + def forward(self, memory: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass. 
+ + Parameters + ---------- + memory : (B, T_tok, D) — Chronos encoder embeddings (mean-pooled over C) + + Returns + ------- + pos_logits : (B, N, 2) — raw (pre-sigmoid) span predictions + cls_logits : (B, N, k) — raw (pre-sigmoid) class logits + """ + B = memory.size(0) + # Expand queries across batch + queries = self.query_embed.weight.unsqueeze(0).expand(B, -1, -1) # (B, N, D) + + decoded = self.decoder(tgt=queries, memory=memory) # (B, N, D) + + pos_logits = self.pos_head(decoded) # (B, N, 2) + cls_logits = self.cls_head(decoded) # (B, N, k) + return pos_logits, cls_logits + + def compute_loss( + self, + pos_logits: torch.Tensor, + cls_logits: torch.Tensor, + y: torch.Tensor, + lambda_cls: float = 1.0, + ) -> torch.Tensor: + """Compute combined position + classification loss with Hungarian matching. + + For each sample in the batch the optimal bipartite assignment between + the N predicted slots and the M real ground-truth events is found via + the Hungarian algorithm. Matched slots are trained on their assigned + GT event; unmatched slots are trained toward the no-event class target + (all-zero class vector). 
+ + Parameters + ---------- + pos_logits : (B, N, 2) — raw span predictions (sigmoid applied inside) + cls_logits : (B, N, k) — raw class logits + y : (B, N, 2+k) — ground-truth targets, float32; + all-zero rows are empty / no-event slots + lambda_cls : float — weight for the class loss term + + Returns + ------- + scalar loss tensor + """ + from scipy.optimize import linear_sum_assignment + + B, N, _ = pos_logits.shape + y_pos = y[..., :2] # (B, N, 2) + y_cls = y[..., 2:] # (B, N, k) + + pos_pred = torch.sigmoid(pos_logits) # (B, N, 2) in [0,1] + cls_prob = torch.sigmoid(cls_logits) # (B, N, k) in [0,1] + + total_pos_loss = pos_logits.new_zeros(()) + total_cls_loss = pos_logits.new_zeros(()) + + for b in range(B): + # Identify real GT events for this sample + has_event_b = y_cls[b].sum(dim=-1) > 0 # (N,) bool mask over GT slots + gt_pos = y_pos[b][has_event_b] # (M, 2) + gt_cls = y_cls[b][has_event_b] # (M, k) + M = gt_pos.shape[0] + + pred_pos_b = pos_pred[b] # (N, 2) + cls_logits_b = cls_logits[b] # (N, k) + cls_prob_b = cls_prob[b] # (N, k) — used for cost matrix only + + if M == 0: + # No GT events: drive all slots to no-event (zero target) + total_cls_loss = total_cls_loss + nn.functional.binary_cross_entropy_with_logits( + cls_logits_b, + torch.zeros_like(cls_logits_b), + reduction="mean", + ) + continue + + # --- cost matrix (N x M) ---------------------------------------- + # L1 position cost: sum of |pred - gt| over the 2 span dimensions + with torch.no_grad(): + cost_pos = torch.cdist(pred_pos_b, gt_pos, p=1) # (N, M) + # Class cost: negative dot product of sigmoid probabilities and + # GT class vectors — lower cost means better class agreement. 
+ cost_cls = -(cls_prob_b @ gt_cls.T) # (N, M) + cost = cost_pos + lambda_cls * cost_cls # (N, M) + + pred_idx, gt_idx = linear_sum_assignment(cost.cpu().numpy()) + + # --- position loss on matched pairs -------------------------------- + matched_pred_pos = pred_pos_b[pred_idx] # (M, 2) + matched_gt_pos = gt_pos[gt_idx] # (M, 2) + pos_loss_b = nn.functional.smooth_l1_loss( + matched_pred_pos, matched_gt_pos, reduction="mean" + ) + total_pos_loss = total_pos_loss + pos_loss_b + + # --- class loss: matched → GT class, unmatched → zero target ------- + cls_target = torch.zeros_like(cls_logits_b) # (N, k) all zeros + cls_target[pred_idx] = gt_cls[gt_idx].to(cls_target.dtype) + cls_loss_b = nn.functional.binary_cross_entropy_with_logits( + cls_logits_b, cls_target, reduction="mean" + ) + total_cls_loss = total_cls_loss + cls_loss_b + + return (total_pos_loss + lambda_cls * total_cls_loss) / B + + +# --------------------------------------------------------------------------- +# ChronosEventAdapter +# --------------------------------------------------------------------------- + +class ChronosEventAdapter(BaseTSFMAdapter): + """Fitted adapter: frozen Chronos encoder + trained EventHead. + + Parameters + ---------- + pipeline : ChronosPipeline + A loaded (and frozen) Chronos pipeline instance. + head : EventHead + A trained EventHead instance on the target device. + device : str + torch device string, e.g. "cuda" or "cpu". + n_classes : int + Number of binary event-class columns k. + T : int + Input series length (used only for documentation; not enforced here). 
+ """ + + def __init__(self, pipeline, head: EventHead, device: str, + n_classes: int, T: int = 512): + self.pipeline = pipeline + self.head = head + self.device = device + self.n_classes = n_classes + self.T = T + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _embed_series(self, x: np.ndarray) -> torch.Tensor: + """Embed a single multivariate series via Chronos, mean-pool over C. + + Parameters + ---------- + x : (T, C) float array + + Returns + ------- + Tensor of shape (T_tok, D) on CPU, float32 + """ + import torch as _torch + C = x.shape[1] + channel_embs = [] + for c in range(C): + ctx = _torch.tensor(x[:, c], dtype=_torch.float32) + # pipeline.embed returns (1, T_tok, D), tokenizer_state + emb, _ = self.pipeline.embed(ctx.unsqueeze(0)) # (1, T_tok, D) + channel_embs.append(emb.squeeze(0).float().cpu()) # (T_tok, D) + # Stack and mean-pool over channels + stacked = _torch.stack(channel_embs, dim=0) # (C, T_tok, D) + return stacked.mean(dim=0) # (T_tok, D) + + # ------------------------------------------------------------------ + # BaseTSFMAdapter interface + # ------------------------------------------------------------------ + + def predict(self, x: np.ndarray) -> np.ndarray: + """Run inference on a single multivariate series. + + Parameters + ---------- + x : (T, C) float array + + Returns + ------- + spans : (N=10, 2+k) float32 array + Columns 0-1 : start and length in [0,1] + Columns 2.. 
: binary class probabilities in [0,1] + """ + self.head.eval() + memory = self._embed_series(x) # (T_tok, D) cpu + memory = memory.unsqueeze(0).to(self.device) # (1, T_tok, D) + + with torch.no_grad(): + pos_logits, cls_logits = self.head(memory) # (1,N,2), (1,N,k) + + pos = torch.sigmoid(pos_logits[0]).cpu().numpy() # (N, 2) + cls = torch.sigmoid(cls_logits[0]).cpu().numpy() # (N, k) + return np.concatenate([pos, cls], axis=-1).astype(np.float32) # (N, 2+k) + + +# --------------------------------------------------------------------------- +# Training helpers (used by solvers/chronos.py) +# --------------------------------------------------------------------------- + +def _get_linear_cosine_scheduler(optimizer, warmup_epochs, total_epochs): + """Linear warmup + cosine annealing LR scheduler (epoch-level).""" + import math + + def lr_lambda(epoch): + if epoch < warmup_epochs: + return float(epoch + 1) / float(max(1, warmup_epochs)) + progress = float(epoch - warmup_epochs) / float( + max(1, total_epochs - warmup_epochs) + ) + return 0.5 * (1.0 + math.cos(math.pi * progress)) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + +def precompute_embeddings(pipeline, X_train, batch_size=128): + """Embed every training series via a frozen Chronos pipeline. + + Each series is embedded channel-by-channel and the per-channel + (T_tok, D) tensors are mean-pooled to produce a single (T_tok, D) + representation. Results are cached on CPU. + + Parameters + ---------- + pipeline : ChronosPipeline or Chronos2Pipeline + A loaded (and frozen) Chronos pipeline exposing ``.embed()``. + X_train : List[np.ndarray (T, C)] + Training time series. + batch_size : int + Number of series to embed in a single ``pipeline.embed()`` call. + Each channel is treated as a separate univariate series, so the + actual call size is ``batch_size * C``. + + Returns + ------- + List[torch.Tensor] — one (T_tok, D) float32 CPU tensor per series. 
+ """ + Z_train = [] + n = len(X_train) + with torch.no_grad(): + for start in range(0, n, batch_size): + batch = X_train[start:start + batch_size] + # Flatten all channels across the batch into one list of 1D tensors. + # Order: x0c0, x0c1, ..., x1c0, x1c1, ... + contexts = [] + channel_counts = [] + for x in batch: + x = np.asarray(x, dtype=np.float32) + C = x.shape[1] + channel_counts.append(C) + for c in range(C): + contexts.append(torch.tensor(x[:, c], dtype=torch.float32)) + + embs, _ = pipeline.embed(contexts) # (sum(C_i), T_tok, D) + embs = embs.float().cpu() + + idx = 0 + for C in channel_counts: + channel_embs = embs[idx:idx + C] # (C, T_tok, D) + Z_train.append(channel_embs.mean(dim=0)) # (T_tok, D) + idx += C + + print(f" Embedded {min(start + batch_size, n)}/{n} series", end="\r") + + print() + return Z_train + + +def _pad_or_truncate_labels(y: np.ndarray, num_queries: int) -> np.ndarray: + """Pad or truncate a label array to exactly ``num_queries`` rows. + + Parameters + ---------- + y : (N, 2+k) float array + Raw event labels for one series. ``N`` may be smaller or larger + than ``num_queries``. + num_queries : int + Target number of event slots. + + Returns + ------- + (num_queries, 2+k) float32 array + Rows beyond the original ``N`` are filled with zeros (no-event). + Rows beyond ``num_queries`` in the original are discarded. + """ + y = np.asarray(y, dtype=np.float32) + N, width = y.shape + if N == num_queries: + return y + if N > num_queries: + return y[:num_queries] + # N < num_queries — pad with zero rows + pad = np.zeros((num_queries - N, width), dtype=np.float32) + return np.concatenate([y, pad], axis=0) + + +def fit_event_head( + Z_train, + y_train, + n_classes, + d_model, + device, + batch_size=32, + num_epochs=100, + lr=3e-4, + weight_decay=1e-4, + warmup_epochs=5, + num_dec_layers=2, + lambda_cls=1.0, + num_queries=10, +): + """Train an EventHead on pre-computed Chronos embeddings. 
+ + Parameters + ---------- + Z_train : List[torch.Tensor (T_tok, D)] + Pre-computed encoder embeddings (CPU), one per training series. + y_train : List[np.ndarray (N_i, 2+k)] + Event targets, one per training series. ``N_i`` may differ across + series and need not equal ``num_queries``; labels are automatically + padded (with zeros) or truncated to ``num_queries`` rows. + n_classes : int + Number of binary class columns k. + d_model : int + Encoder hidden dimension D. + device : str + Torch device, e.g. ``"cuda"`` or ``"cpu"``. + batch_size, num_epochs, lr, weight_decay : training hyperparameters. + warmup_epochs : int + Linear warmup duration; cosine decay thereafter. + num_dec_layers : int + Transformer decoder depth. + lambda_cls : float + Weight of the classification loss relative to the position loss. + num_queries : int + Number of event slots (decoder queries); default 10. + + Returns + ------- + EventHead — trained, in eval mode, on ``device``. + """ + # Normalise all labels to (num_queries, 2+k) once, before the training loop + y_train = [_pad_or_truncate_labels(y, num_queries) for y in y_train] + + head = EventHead( + d_model=d_model, + n_classes=n_classes, + num_queries=num_queries, + num_decoder_layers=num_dec_layers, + nhead=8, + ).to(device) + head.train() + + optimizer = torch.optim.AdamW( + head.parameters(), lr=lr, weight_decay=weight_decay + ) + scheduler = _get_linear_cosine_scheduler(optimizer, warmup_epochs, num_epochs) + + N_train = len(Z_train) + use_amp = "cuda" in device + scaler = torch.amp.GradScaler(device, enabled=use_amp) + num_batches_per_epoch = max(1, int(np.ceil(N_train / batch_size))) + + for epoch in range(num_epochs): + indices = np.random.permutation(N_train) + epoch_loss = 0.0 + num_batches = 0 + + for batch_start in range(0, N_train, batch_size): + batch_idx = indices[batch_start: batch_start + batch_size] + + embs = [Z_train[i] for i in batch_idx] + max_ttok = max(e.shape[0] for e in embs) + D = embs[0].shape[1] + B = 
len(embs) + + memory = torch.zeros(B, max_ttok, D, dtype=torch.float32) + for bi, e in enumerate(embs): + memory[bi, : e.shape[0]] = e + memory = memory.to(device) + + y_batch = torch.tensor( + np.stack([y_train[i] for i in batch_idx]), + dtype=torch.float32, + device=device, + ) # (B, N, 2+k) + + optimizer.zero_grad() + with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=use_amp): + pos_logits, cls_logits = head(memory) + loss = head.compute_loss( + pos_logits, cls_logits, y_batch, lambda_cls=lambda_cls + ) + + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(head.parameters(), max_norm=1.0) + scaler.step(optimizer) + scaler.update() + + epoch_loss += loss.item() + num_batches += 1 + + print( + f" Epoch {epoch + 1:3d}/{num_epochs} | " + f"step {num_batches:4d}/{num_batches_per_epoch} | " + f"loss={loss.item():.4f}", + end="\r", + ) + + scheduler.step() + + avg = epoch_loss / num_batches + lr_now = scheduler.get_last_lr()[0] + print( + f" Epoch {epoch + 1:3d}/{num_epochs} | " + f"loss={avg:.4f} | lr={lr_now:.2e}" + + " " * 20 # clear leftover step line + ) + + head.eval() + return head diff --git a/benchmark_utils/metrics.py b/benchmark_utils/metrics.py index 759dd6b..12bb596 100644 --- a/benchmark_utils/metrics.py +++ b/benchmark_utils/metrics.py @@ -496,6 +496,372 @@ def map_iou(y_true, y_pred, iou_threshold=0.5): return float(np.mean(valid)) if valid else float("nan") +def _collect_tp_pairs(gt_by_series, preds, iou_threshold): + """Greedy IoU matching. 
def _match_events(gt_by_series, preds, iou_threshold):
    """Greedy IoU matching (highest-score first). Returns (tp, fp, fn).

    ``preds`` is consumed in the order given, so the caller is responsible
    for sorting it by descending score.  Each prediction is compared with
    every ground-truth box of its own series; it is a true positive only if
    its single best-IoU box reaches ``iou_threshold`` and has not been
    claimed yet.  Otherwise (including duplicates of an already matched box)
    it is a false positive.

    Parameters
    ----------
    gt_by_series : dict[int, list[tuple(start, width)]]
    preds        : iterable of (series_index, start, width, score)
    iou_threshold: minimum IoU for a match
    """
    matched = {i: [False] * len(boxes) for i, boxes in gt_by_series.items()}
    n_gt = sum(len(b) for b in gt_by_series.values())
    tp = fp = 0
    for i, s, w, _ in preds:
        best_iou, best_gi = 0.0, -1
        for gi, (gs, gw) in enumerate(gt_by_series.get(i, [])):
            iou = _iou_1d(s, w, gs, gw)
            if iou > best_iou:
                best_iou, best_gi = iou, gi
        if best_iou >= iou_threshold and best_gi >= 0 and not matched[i][best_gi]:
            tp += 1
            matched[i][best_gi] = True
        else:
            fp += 1
    return tp, fp, n_gt - tp


def f1_det(y_true, y_pred, iou_threshold=0.5, score_threshold=None):
    """Macro F1 for event detection.

    For every class the ground-truth boxes of that class are matched against
    all predicted boxes (scored by that class's column) and an F1 is
    computed; the result is the mean over classes that have at least one
    ground-truth box.

    Parameters
    ----------
    y_true, y_pred : same format as map_iou
        Ground-truth rows whose class columns are all zero are padding
        (empty slots) and are ignored.
    iou_threshold  : minimum IoU to count a match (default 0.5)
    score_threshold: fixed class score threshold; if None, sweep to maximise F1
                     per class (oracle — for benchmarking purposes only)

    Returns
    -------
    float — macro F1, or NaN when ``y_true`` is empty or no class has any
    ground-truth event.
    """
    if not y_true:
        return float("nan")

    n_classes = y_true[0].shape[1] - 2
    f1s = []

    for k in range(n_classes):
        gt_by_series = {}
        n_gt = 0
        for i, gt in enumerate(y_true):
            # An all-zero class vector has argmax 0, so padding rows must be
            # excluded explicitly or they would be counted as class-0 events.
            boxes = [
                (row[0], row[1]) for row in gt
                if np.any(row[2:] > 0) and np.argmax(row[2:]) == k
            ]
            gt_by_series[i] = boxes
            n_gt += len(boxes)

        if n_gt == 0:
            # Class absent from the ground truth: excluded from the macro mean.
            f1s.append(float("nan"))
            continue

        all_preds = [
            (i, row[0], row[1], float(row[2 + k]))
            for i, pred in enumerate(y_pred)
            for row in pred
        ]
        # Highest score first, as required by _match_events.
        all_preds.sort(key=lambda x: -x[3])

        if score_threshold is None:
            scores = np.array([p[3] for p in all_preds])
            # Candidate thresholds are the score percentiles; duplicates give
            # identical F1 values, so they are dropped before the sweep.
            thresholds = (
                np.unique(np.percentile(scores, np.arange(0, 100, 1)))
                if len(scores) else [0.5]
            )
            best_f1 = 0.0
            for thr in thresholds:
                preds = [p for p in all_preds if p[3] >= thr]
                tp, fp, fn = _match_events(gt_by_series, preds, iou_threshold)
                denom = 2 * tp + fp + fn
                best_f1 = max(best_f1, (2 * tp / denom) if denom > 0 else 0.0)
            f1s.append(best_f1)
        else:
            preds = [p for p in all_preds if p[3] >= score_threshold]
            tp, fp, fn = _match_events(gt_by_series, preds, iou_threshold)
            denom = 2 * tp + fp + fn
            f1s.append((2 * tp / denom) if denom > 0 else 0.0)

    valid = [f for f in f1s if not np.isnan(f)]
    return float(np.mean(valid)) if valid else float("nan")
+ + Computes precision/recall/F1 of event spans across all series using greedy + IoU matching. A predicted span is a true positive if its IoU with an + unmatched ground-truth span exceeds ``iou_threshold``. + + Parameters + ---------- + y_true : List[np.ndarray (N=10, 2+k)] + Ground-truth padded event targets. All-zero rows = empty slots. + y_pred : List[np.ndarray (N=10, 2+k)] + Predicted outputs. Positions (cols 0-1) in [0,1]; class probs in [0,1]. + iou_threshold : float + Minimum IoU to count as a correct span detection (default 0.5). + + Returns + ------- + float — span F1 score averaged over all series + """ + f1_scores = [] + for gt, pr in zip(y_true, y_pred): + gt = np.asarray(gt) + pr = np.asarray(pr) + + gt_mask = gt[:, 2:].sum(axis=1) > 0 + pr_mask = pr[:, 2:].max(axis=1) > 0.5 + + gt_spans = gt[gt_mask, :2] + pr_spans = pr[pr_mask, :2] + + G = gt_spans.shape[0] + P = pr_spans.shape[0] + + if G == 0 and P == 0: + f1_scores.append(1.0) + continue + if G == 0 or P == 0: + f1_scores.append(0.0) + continue + + matched_gt = set() + tp = 0 + for pi in range(P): + best_iou = 0.0 + best_gi = -1 + for gi in range(G): + if gi in matched_gt: + continue + iou = _span_iou( + pr_spans[pi, 0], pr_spans[pi, 1], + gt_spans[gi, 0], gt_spans[gi, 1], + ) + if iou > best_iou: + best_iou = iou + best_gi = gi + if best_iou >= iou_threshold and best_gi >= 0: + matched_gt.add(best_gi) + tp += 1 + + precision = tp / P if P > 0 else 0.0 + recall = tp / G if G > 0 else 0.0 + f1 = (2 * precision * recall / (precision + recall) + if precision + recall > 0 else 0.0) + f1_scores.append(f1) + + return float(np.mean(f1_scores)) + + +def _match_spans(gt_spans, pr_spans, iou_threshold, matching_strategy): + """Match predicted spans to ground-truth spans by 1-D IoU. + + Each ground-truth span and each prediction is matched at most once; pairs + whose IoU is below ``iou_threshold`` are discarded so that unmatched and + duplicate predictions are left out of the matching. 
+ + Parameters + ---------- + gt_spans : np.ndarray (G, 2+k) + Ground-truth spans (cols 0-1 = start/width, cols 2: = class columns). + pr_spans : np.ndarray (P, 2+k) + Predicted spans, same layout; class columns hold scores in [0, 1]. + iou_threshold : float + Minimum IoU to accept a match. + matching_strategy : {"greedy", "hungarian"} + ``"greedy"`` sorts predictions by descending confidence and assigns + each one to its highest-IoU still-free ground-truth span. + ``"hungarian"`` solves the maximum-IoU bipartite assignment globally. + + Returns + ------- + List[Tuple[int, int]] + ``(gi, pi)`` index pairs of matched ground-truth and prediction spans. + """ + G = gt_spans.shape[0] + P = pr_spans.shape[0] + if G == 0 or P == 0: + return [] + + # IoU matrix: rows = predictions, cols = ground-truth spans. + iou = np.zeros((P, G)) + for pi in range(P): + for gi in range(G): + iou[pi, gi] = _span_iou( + pr_spans[pi, 0], pr_spans[pi, 1], + gt_spans[gi, 0], gt_spans[gi, 1], + ) + + if matching_strategy == "greedy": + # Sort predictions by descending confidence (max class score) so that + # the most confident prediction claims its best ground-truth first. + order = np.argsort(-pr_spans[:, 2:].max(axis=1)) + matched_gt = set() + pairs = [] + for pi in order: + best_iou = 0.0 + best_gi = -1 + for gi in range(G): + if gi in matched_gt: + continue + if iou[pi, gi] > best_iou: + best_iou = iou[pi, gi] + best_gi = gi + # Accept the match only if the best free ground-truth clears the + # threshold; the chosen ground-truth is then marked as occupied. + if best_gi >= 0 and best_iou >= iou_threshold: + matched_gt.add(best_gi) + pairs.append((best_gi, pi)) + return pairs + + if matching_strategy == "hungarian": + # Maximum-weight bipartite matching on the IoU matrix (negated for the + # minimisation solver), keeping only pairs above the threshold. 
+ from scipy.optimize import linear_sum_assignment + + pr_idx, gt_idx = linear_sum_assignment(-iou) + return [(int(gi), int(pi)) for pi, gi in zip(pr_idx, gt_idx) + if iou[pi, gi] >= iou_threshold] + + raise ValueError( + f"Unknown matching_strategy={matching_strategy!r}; " + "expected 'greedy' or 'hungarian'." + ) + + +def _f1_from_class_counts(tp, fp, fn, mode): + """Combine per-class TP/FP/FN counts into a single F1 score. + + Parameters + ---------- + tp, fp, fn : np.ndarray (k,) + Per-class true-positive, false-positive and false-negative counts. + mode : {"micro", "macro"} + ``"micro"`` pools counts across classes before computing one F1. + ``"macro"`` computes per-class F1 then averages them equally. + """ + if mode == "micro": + tp_s, fp_s, fn_s = tp.sum(), fp.sum(), fn.sum() + precision = tp_s / (tp_s + fp_s) if (tp_s + fp_s) > 0 else 0.0 + recall = tp_s / (tp_s + fn_s) if (tp_s + fn_s) > 0 else 0.0 + if precision + recall > 0: + return float(2 * precision * recall / (precision + recall)) + return 0.0 + + if mode == "macro": + f1s = [] + for k in range(len(tp)): + denom_p = tp[k] + fp[k] + denom_r = tp[k] + fn[k] + precision = tp[k] / denom_p if denom_p > 0 else 0.0 + recall = tp[k] / denom_r if denom_r > 0 else 0.0 + if precision + recall > 0: + f1s.append(2 * precision * recall / (precision + recall)) + else: + f1s.append(0.0) + return float(np.mean(f1s)) if f1s else 0.0 + + raise ValueError( + f"Unknown mode={mode!r}; expected 'micro' or 'macro'." + ) + + +def event_iou_f1(y_true, y_pred, iou_threshold=0.5, + matching_strategy="greedy", mode="micro"): + """F1 over binary class columns on IoU-matched event spans. + + Predicted spans are matched to ground-truth spans by 1-D IoU using the + requested strategy. For each matched pair, class probabilities are + thresholded at 0.5 and contribute true positives / false positives / false + negatives over the k binary class columns. 
Unmatched ground-truth spans + count as false negatives for all their active classes, while unmatched and + duplicate predictions count as false positives for all their active + classes, so both kinds of error are penalised. + + Parameters + ---------- + y_true : List[np.ndarray (N, 2+k)] + Ground-truth padded event targets. All-zero class columns = empty slot. + y_pred : List[np.ndarray (N, 2+k)] + Predicted outputs. Cols 0-1 = start/width, cols 2: = class scores. + iou_threshold : float + Minimum IoU to accept a span match (default 0.5). + matching_strategy : {"greedy", "hungarian"} + Span matching strategy (default "greedy"). See :func:`_match_spans`. + mode : {"micro", "macro"} + Class-averaging mode for the final F1 (default "micro"). + + Returns + ------- + float — class F1 score aggregated over all series. + """ + # Infer the number of class columns from the first 2-D sample available. + n_classes = None + for arr in list(y_true) + list(y_pred): + a = np.asarray(arr) + if a.ndim == 2 and a.shape[1] > 2: + n_classes = a.shape[1] - 2 + break + if n_classes is None: + return float("nan") + + tp = np.zeros(n_classes) + fp = np.zeros(n_classes) + fn = np.zeros(n_classes) + + for gt, pr in zip(y_true, y_pred): + gt = np.asarray(gt) + pr = np.asarray(pr) + + # Keep only occupied ground-truth slots and confident predictions. + gt_mask = gt[:, 2:].sum(axis=1) > 0 + pr_mask = pr[:, 2:].max(axis=1) > 0.5 + + gt_spans = gt[gt_mask] + pr_spans = pr[pr_mask] + + G = gt_spans.shape[0] + P = pr_spans.shape[0] + + pairs = _match_spans(gt_spans, pr_spans, iou_threshold, + matching_strategy) + matched_gt = {gi for gi, _ in pairs} + matched_pr = {pi for _, pi in pairs} + + # Matched pairs: per-class agreement on the thresholded class columns. 
+ for gi, pi in pairs: + gt_cls = (gt_spans[gi, 2:] > 0.5).astype(int) + pr_cls = (pr_spans[pi, 2:] > 0.5).astype(int) + tp += gt_cls & pr_cls + fp += (1 - gt_cls) & pr_cls + fn += gt_cls & (1 - pr_cls) + + # Unmatched ground-truth spans: every active class is a false negative. + for gi in range(G): + if gi not in matched_gt: + fn += (gt_spans[gi, 2:] > 0.5).astype(int) + + # Unmatched / duplicate predictions: every active class is a false + # positive (penalty for over-prediction). + for pi in range(P): + if pi not in matched_pr: + fp += (pr_spans[pi, 2:] > 0.5).astype(int) + + return _f1_from_class_counts(tp, fp, fn, mode) + + # --------------------------------------------------------------------------- # Registry: maps metric name → function # --------------------------------------------------------------------------- @@ -529,7 +895,14 @@ def map_iou(y_true, y_pred, iou_threshold=0.5): EVENT_METRICS = { "map_iou": map_iou, + "f1_det": f1_det, + "event_span_iou": event_span_iou, + "event_iou_f1": event_iou_f1, } -ALL_METRICS = {**FORECASTING_METRICS, **CLASSIFICATION_METRICS, **AD_METRICS, - **EVENT_METRICS} +ALL_METRICS = { + **FORECASTING_METRICS, + **CLASSIFICATION_METRICS, + **AD_METRICS, + **EVENT_METRICS, +} diff --git a/datasets/mitdb.py b/datasets/mitdb.py index f8a06f5..4d392cd 100644 --- a/datasets/mitdb.py +++ b/datasets/mitdb.py @@ -9,17 +9,28 @@ class 2 V — Ventricular ectopic class 3 F — Fusion class 4 Q — Unknown / pacemaker artefact +Windowing +--------- +Each record is split into train/test portions and then sliced into fixed-size +overlapping windows (default: window_size=512 samples, overlap=50% → stride 256). +Only beat events whose full span [R-peak − beat_window, R-peak + beat_window) +lies **entirely within** the window are included. Events that straddle a +window boundary are discarded. Positions are normalised to [0, 1] over the +window length. 
+ Data contract output -------------------- -X_train : List[np.ndarray (T_i, 2)] training portions (C == 2) -y_train : List[np.ndarray (N_i, 2+K)] one row per beat event: - col 0 start (normalised, 0–1) - col 1 width (normalised, 0–1) - cols 2… one-hot class vector (K) -X_test : List[np.ndarray (T_j, 2)] test portions -y_test : List[np.ndarray (N_j, 2+K)] same format +X_train : List[np.ndarray (512, 2)] one window per element (C == 2) +y_train : List[np.ndarray (N, 2+K)] one row per beat event, zero-padded + to N = max events across all windows: + col 0 start (normalised, 0–1) + col 1 width (normalised, 0–1) + cols 2… one-hot class vector (K) + all-zero row = empty/padding slot +X_test : List[np.ndarray (512, 2)] test windows +y_test : List[np.ndarray (N, 2+K)] same format, same N task : "event_detection" -metrics : ["map_iou"] +metrics : ["map_iou", "event_iou_f1"] extra : n_classes (int) K above """ @@ -116,9 +127,76 @@ def _annotations_to_events(n_samples, ann_samples, ann_symbols, beat_window, return np.array(rows, dtype=np.float32) +def _extract_window_events(ann_samples, ann_symbols, w_start, window_size, + beat_window, n_classes): + """Extract events that are fully contained within a single window. + + Each event is represented as ``(start, length)`` normalised to [0, 1] + over ``window_size``. An event centred at R-peak position ``s`` + (in segment coordinates) has: + + event_start = s - beat_window + event_length = 2 * beat_window (constant for all beats) + + The event is included **only if** its full span lies within the window: + + event_start >= w_start AND + event_start + event_length <= w_start + window_size + + Events that straddle a window boundary are discarded entirely. + + Parameters + ---------- + ann_samples : np.ndarray (A,) int annotation positions in segment coords + ann_symbols : list of str annotation symbols (len A) + w_start : int window start position in segment coords + window_size : int window length in samples (e.g. 
512) + beat_window : int half-width of each event box in samples + n_classes : int K — number of AAMI classes + + Returns + ------- + events : np.ndarray (E, 2+K) float32 + Each row: [start_norm, length_norm, *one_hot_class] where + start_norm = (event_start - w_start) / window_size in [0, 1] + length_norm = event_length / window_size in [0, 1] + E=0 if no events are fully contained in the window. + """ + event_length = 2 * beat_window # constant: full span of every beat box + rows = [] + for sample, symbol in zip(ann_samples, ann_symbols): + aami_class = BEAT_CLASS.get(symbol) + if aami_class is None: + continue + class_idx = 0 if n_classes == 1 else aami_class - 1 + if class_idx >= n_classes: + continue + + event_start = sample - beat_window + + # Discard events whose span crosses a window boundary + if event_start < w_start or event_start + event_length > w_start + window_size: + continue + + one_hot = np.zeros(n_classes, dtype=np.float32) + one_hot[class_idx] = 1.0 + start_norm = (event_start - w_start) / window_size + length_norm = event_length / window_size + rows.append([start_norm, length_norm, *one_hot]) + + if not rows: + return np.zeros((0, 2 + n_classes), dtype=np.float32) + return np.array(rows, dtype=np.float32) + + class Dataset(BaseDataset): """MIT-BIH Arrhythmia Database for 1-D event detection. + Each record is split chronologically into train/test portions, then each + portion is sliced into fixed-size overlapping windows. Only beat events + that fit entirely within a window are kept; boundary-straddling events are + discarded. Positions are normalised to [0, 1] over ``window_size``. + Parameters ---------- record_ids : list of str or "all" @@ -134,6 +212,11 @@ class Dataset(BaseDataset): K — number of AAMI beat classes to distinguish (1–5). Classes are ordered N, S, V, F, Q; setting n_classes=1 collapses all annotated beats into a single "beat" class. + window_size : int + Length of each window in samples (default 512). 
+ window_overlap : float + Fractional overlap between consecutive windows in [0, 1). + Default 0.5 → stride = window_size // 2 = 256 samples. """ name = "MITDB" @@ -146,6 +229,8 @@ class Dataset(BaseDataset): "train_ratio": [0.7], "beat_window": [36], "n_classes": [5], + "window_size": [512], + "window_overlap": [0.5], } def get_data(self): @@ -155,6 +240,8 @@ def get_data(self): if self.debug: record_ids = record_ids[:2] + stride = max(1, int(self.window_size * (1.0 - self.window_overlap))) + X_train, y_train, X_test, y_test = [], [], [], [] for rid in record_ids: signal, ann_samples, ann_symbols = _load_record(rid, data_dir) @@ -167,25 +254,53 @@ def get_data(self): split = max(1, int(len(signal) * self.train_ratio)) - for seg_signal, start, end, Xl, yl in [ + for seg_signal, seg_start, seg_end, Xl, yl in [ (signal[:split], 0, split, X_train, y_train), (signal[split:], split, len(signal), X_test, y_test), ]: - seg_ann = ann_samples[(ann_samples >= start) & (ann_samples < end)] - start + # Annotations relative to the segment (0-based within seg_signal) + seg_ann = (ann_samples[(ann_samples >= seg_start) & (ann_samples < seg_end)] + - seg_start) seg_sym = [s for s, idx in zip(ann_symbols, ann_samples) - if start <= idx < end] - Xl.append(seg_signal) - yl.append(_annotations_to_events( - len(seg_signal), seg_ann, seg_sym, - self.beat_window, self.n_classes, - )) + if seg_start <= idx < seg_end] + + seg_len = len(seg_signal) + if seg_len < self.window_size: + # Segment too short for even one window — skip + continue + + for w_start in range(0, seg_len - self.window_size + 1, stride): + window = seg_signal[w_start: w_start + self.window_size] + events = _extract_window_events( + seg_ann, seg_sym, + w_start, self.window_size, + self.beat_window, self.n_classes, + ) + Xl.append(window) + yl.append(events) + + # Pad all event arrays to a uniform N so solvers can np.stack them. + # N = max events in any single window across train and test. 
+ all_y = y_train + y_test + max_n = max((y.shape[0] for y in all_y), default=0) + n_cols = 2 + self.n_classes + + def _pad(arrays): + out = [] + for y in arrays: + n = y.shape[0] + if n < max_n: + pad = np.zeros((max_n - n, n_cols), dtype=np.float32) + y = np.concatenate([y, pad], axis=0) + out.append(y) + return out return dict( X_train=X_train, - y_train=y_train, + y_train=_pad(y_train), X_test=X_test, - y_test=y_test, + y_test=_pad(y_test), task="event_detection", - metrics=["map_iou"], + metrics=["map_iou", "event_iou_f1"], n_classes=self.n_classes, ) diff --git a/objective.py b/objective.py index 6591178..923ee27 100644 --- a/objective.py +++ b/objective.py @@ -1,8 +1,8 @@ """ Unified objective for the TSFM benchmark. -Supports three tasks — forecasting, classification, anomaly detection — -dispatched via the ``task`` field provided by each dataset. +Supports four tasks — forecasting, classification, anomaly detection, and +event detection — dispatched via the ``task`` field provided by each dataset. Data contract ------------- @@ -35,9 +35,15 @@ extra n_classes (int) anomaly_detection y_train None y_test List[(T_j,)] int point-level binary labels -event_detection y_train List[(N_i, 2+K)] float object-detection boxes - y_test List[(N_j, 2+K)] float object-detection boxes - extra n_classes (int) +event_detection y_train List[(N, 2+K)] float object-detection boxes, + zero-padded to uniform N + y_test List[(N, 2+K)] float same format, same N + extra n_classes (int), T (int, default 512) + + Each row of the (N, 2+k) array is one event slot: + col 0 : start, normalised to [0, 1] over T + col 1 : length, normalised to [0, 1] over T + col 2.. 
: k binary class columns (all-zero row = empty slot) Solver contract --------------- @@ -169,26 +175,34 @@ def _eval_classification(self, model): result[name] = ALL_METRICS[name](y_true, y_pred) return result - # --- event detection ----------------------------------------------- + # --- anomaly detection --------------------------------------------- - def _eval_event_detection(self, model): - # model.predict returns (N, 2+K) float array per series - preds = [np.asarray(model.predict(x)) for x in self.X_test] + def _eval_anomaly_detection(self, model): + # model.predict returns (T_j,) float scores per series + scores = [np.asarray(model.predict(x)) for x in self.X_test] result = {} for name in self.metrics: - result[name] = ALL_METRICS[name](self.y_test, preds) + result[name] = ALL_METRICS[name](self.y_test, scores) return result - # --- anomaly detection --------------------------------------------- + # --- event detection ----------------------------------------------- - def _eval_anomaly_detection(self, model): - # model.predict returns (T_j,) float scores per series - scores = [np.asarray(model.predict(x)) for x in self.X_test] + def _eval_event_detection(self, model): + """Evaluate event detection. + + model.predict(x) must return (N, 2+k) float array per series: + col 0 : predicted start in [0, 1] + col 1 : predicted length in [0, 1] + col 2.. : binary class probabilities / scores in [0, 1] + + y_test is a List of (N, 2+k) event arrays (variable or padded). 
+ """ + preds = [np.asarray(model.predict(x)) for x in self.X_test] result = {} for name in self.metrics: - result[name] = ALL_METRICS[name](self.y_test, scores) + result[name] = ALL_METRICS[name](self.y_test, preds) return result # ------------------------------------------------------------------ diff --git a/solvers/chronos.py b/solvers/chronos.py index 1a507c1..6e9f763 100644 --- a/solvers/chronos.py +++ b/solvers/chronos.py @@ -4,9 +4,10 @@ - forecasting : zero-shot via ChronosPipeline - classification : linear probe on pooled encoder embeddings - anomaly_detection : forecast-residual on top of the same forecaster + - event_detection : EventHead trained on frozen encoder embeddings Model loading is done in ``set_objective`` (untimed). Inference batches -every (series, cutoff) pair into a single ``ChronosPipeline.predict`` +every (series, cutoff) pair into a single ``Chronos2Pipeline.predict`` call — the pipeline accepts a list of variable-length tensors and applies left-padding internally, so all the per-cutoff work happens in one forward pass. 
@@ -18,7 +19,6 @@ import numpy as np import torch -from benchopt import BaseSolver from chronos import ChronosPipeline from benchmark_utils.adapters import ( @@ -30,11 +30,20 @@ UnpooledEncoder, ) from benchmark_utils.adapters.base import BaseTSFMAdapter +from benchmark_utils.adapters.event_detection import ( + ChronosEventAdapter, + fit_event_head, + precompute_embeddings, +) from benchmark_utils.adapters.forecast_residual import ForecastResidualAdapter +from benchmark_utils.base_solver import BaseTSFMSolver from benchmark_utils.inputs import ForecastInput from benchmark_utils.outputs import ForecastOutput -SUPPORTED_TASKS = {"forecasting", "classification", "anomaly_detection"} +SUPPORTED_TASKS = {"forecasting", "classification", "anomaly_detection", "event_detection"} + +# Chronos encoder output dimension by model size +_CHRONOS_D = {"tiny": 64, "mini": 128, "small": 512, "base": 768, "large": 1024} POOLERS = { "mean": MeanPooler, @@ -43,6 +52,14 @@ } +def _to_context(x): + """Reshape ``(T, V)`` or ``(B, T, V)`` to Chronos input ``(B, V, T)``.""" + x = np.asarray(x, dtype=np.float32) + if x.ndim == 2: + x = x[None] + return x.transpose(0, 2, 1) + + class _ChronosForecaster(BaseTSFMAdapter): """Batched Chronos v1 adapter; quantiles are derived from sample draws.""" @@ -111,29 +128,22 @@ def _assemble_output(self, samples, layout, per_series_shape, prediction_length) class _ChronosEmbedEncoder(UnpooledEncoder): - """Default path — uses ``ChronosPipeline.embed``. + """Default path — uses ``Chronos2Pipeline.embed``. - Returns hidden states *after* ``encoder.final_layer_norm``. + Returns hidden states *after* ``encoder.final_layer_norm`` for each + series in the batch. """ - def __init__(self, pipeline: ChronosPipeline): + def __init__(self, pipeline): self.pipeline = pipeline def encode(self, X) -> np.ndarray: - # X: (B, T, V) or (T, V). 
- X = np.asarray(X, dtype=np.float32) - batched = X.ndim == 3 - if not batched: - X = X[None] # (1, T, V) - B, T, V = X.shape - - # Chronos is univariate — flatten B & V into the batch axis. - flat = X.reshape(B * V, T) # (B*V, T) + context = _to_context(X) # (B, V, T) with torch.no_grad(): - emb, _ = self.pipeline.embed(torch.from_numpy(flat)) # (B*V, T_tok, D) - - # (B*V, T_tok, D) -> (B, T_tok, V, D) - return emb.float().cpu().numpy().reshape(B, -1, V, emb.shape[-1]) + # embed returns a list of B tensors, each of shape (V, T, D). + embeddings, _ = self.pipeline.embed(context) + stacked = torch.stack(list(embeddings)) # (B, V, T, D) + return stacked.transpose(1, 2).float().cpu().numpy() # (B, T, V, D) class _ChronosHookEncoder(UnpooledEncoder): @@ -143,7 +153,7 @@ class _ChronosHookEncoder(UnpooledEncoder): indices are allowed (``-1`` = last block). """ - def __init__(self, pipeline: ChronosPipeline, layer: int): + def __init__(self, pipeline, layer: int): self.pipeline = pipeline n_blocks = len(pipeline.model.model.encoder.block) if not -n_blocks <= layer < n_blocks: @@ -152,15 +162,8 @@ def __init__(self, pipeline: ChronosPipeline, layer: int): ) self._block_idx = layer % n_blocks - def encode(self, X) -> np.ndarray: - X = np.asarray(X, dtype=np.float32) - batched = X.ndim == 3 - if not batched: - X = X[None] # (1, T, V) - B, T, V = X.shape - - flat = X.reshape(B * V, T) # (B*V, T) - context = torch.from_numpy(flat) + def encode(self, x: np.ndarray) -> np.ndarray: + context = _to_context(x) # (B, V, T) token_ids, attn_mask, _ = self.pipeline.tokenizer.context_input_transform( context ) @@ -172,6 +175,8 @@ def encode(self, X) -> np.ndarray: captured = {} def _hook(_module, _inputs, output): + # Hook to capture the embeddings while performing a forward pass + # T5Block returns a tuple; first element is the hidden state. 
hidden = output[0] if isinstance(output, tuple) else output captured["h"] = hidden.detach() @@ -182,19 +187,20 @@ def _hook(_module, _inputs, output): finally: handle.remove() - # (B*V, T_tok, D) -> (B, T_tok, V, D) + # (B*V, T_tok, D) -> (B, T_tok, V, D) matching _ChronosEmbedEncoder output shape + B, V = context.shape[:2] + D = captured["h"].shape[-1] return ( captured["h"] + .reshape(B, V, -1, D) + .permute(0, 2, 1, 3) .float() .cpu() .numpy() - .reshape(B, -1, V, captured["h"].shape[-1]) ) -def ChronosEncoder( - pipeline: ChronosPipeline, layer: int | None = None -) -> UnpooledEncoder: +def ChronosEncoder(pipeline, layer: int | None = None) -> UnpooledEncoder: """Build a Chronos feature extractor. Parameters @@ -230,8 +236,8 @@ def ChronosEncoder( # --------------------------------------------------------------------------- -class Solver(BaseSolver): - """Chronos zero-shot solver. +class Solver(BaseTSFMSolver): + """Chronos-2 zero-shot solver. Parameters ---------- @@ -242,22 +248,29 @@ class Solver(BaseSolver): ``ChronosPipeline.embed`` (post-final-norm). pooler : {"mean", "max", "last"} Pooling strategy over the time-token axis for classification. - task_adaptation : str - Per-task usage of the forecaster: - ``"zeroshot"`` — direct forecasting (forecasting only) - ``"forecast_residual"`` — anomaly score = forecast error (AD only) + model_path : str + Local directory path to load the Chronos model from. When empty + (default), the model is loaded from HuggingFace Hub. 
""" name = "Chronos" - requirements = ["pip::chronos-forecasting>=2.2", "pip::torch"] - - sampling_strategy = "run_once" + requirements = ["pip::chronos-forecasting>=2.2,<3"] parameters = { "model_size": ["small"], "layer": [None], "pooler": ["mean"], + # event_detection — single values so no cross-product for other tasks + "model_path": [""], + "batch_size": [64], + "num_epochs": [3], + "lr": [3e-4], + "weight_decay": [1e-4], + "warmup_epochs": [1], + "num_dec_layers": [2], + "lambda_cls": [1.0], + "num_queries": [10], "classifier": ["log_reg"], "penalty": ["l2"], "C": [1.0], @@ -265,37 +278,144 @@ class Solver(BaseSolver): "n_iterators": [100], } - def skip(self, task, **kwargs): - if task not in SUPPORTED_TASKS: - return True, f"Chronos solver does not support task={task!r}" - return False, None - - def set_objective(self, X_train, y_train, task, **meta): - self.task = task - self.X_train = X_train - self.y_train = y_train - self.meta = meta - - # bfloat16 is fine on CUDA but poorly supported on CPU / MPS; - # fall back to float32 there so inference doesn't crash or stall. - device = "cuda" if torch.cuda.is_available() else "cpu" - dtype = torch.bfloat16 if device == "cuda" else torch.float32 - model_id = f"amazon/chronos-t5-{self.model_size}" + def __init__( + self, + model_size="small", + layer=None, + pooler="mean", + model_path="", + batch_size=32, + num_epochs=100, + lr=3e-4, + weight_decay=1e-4, + warmup_epochs=5, + num_dec_layers=2, + lambda_cls=1.0, + num_queries=10, + classifier="log_reg", + penalty="l2", + C=1.0, + alpha=1.0, + n_iterators=100, + ): + """Initialize Chronos-specific state. + + Parameters + ---------- + model_size : str, default="small" + Chronos model variant to load. + layer : int or None, default=None + Encoder block index for classification embeddings. + pooler : {"mean", "max", "last"}, default="mean" + Pooling strategy over the time-token axis for classification. 
+ model_path : str, default="" + Local model directory; empty = load from HuggingFace Hub. + """ + super().__init__( + model_size=model_size, + layer=layer, + pooler=pooler, + model_path=model_path, + batch_size=batch_size, + num_epochs=num_epochs, + lr=lr, + weight_decay=weight_decay, + warmup_epochs=warmup_epochs, + num_dec_layers=num_dec_layers, + lambda_cls=lambda_cls, + num_queries=num_queries, + classifier=classifier, + penalty=penalty, + C=C, + alpha=alpha, + n_iterators=n_iterators, + ) + self._pipeline = None + self._loaded_model = None + + @property + def supported_tasks(self): + return SUPPORTED_TASKS + + def load_model(self, device, dtype): + """Load Chronos pipeline (cached if already loaded). + + Chronos-2 models (autogluon/chronos-2-*) are loaded via + ``Chronos2Pipeline``; Chronos-1 / T5-based models (e.g. a local + ``chronos_t5_*`` checkpoint) are loaded via ``ChronosPipeline``. + Detection is done by inspecting the ``architectures`` field of the + model config; if ``Chronos2Pipeline`` raises ``AttributeError`` the + loader falls back to ``ChronosPipeline`` automatically. 
+ """ + from chronos import ChronosPipeline, Chronos2Pipeline + + model_id = f"autogluon/chronos-2-{self.model_size}" + model_id = self.model_path if self.model_path else model_id if not hasattr(self, "_pipeline") or self._loaded_model != model_id: - self._pipeline = ChronosPipeline.from_pretrained( - model_id, - device_map=device, - dtype=dtype, - ) + try: + self._pipeline = Chronos2Pipeline.from_pretrained( + model_id, + device_map=device, + dtype=dtype, + ) + except (AttributeError, ValueError): + # Chronos-1 / T5-based checkpoint — fall back to ChronosPipeline + self._pipeline = ChronosPipeline.from_pretrained( + model_id, + device_map=device, + dtype=dtype, + ) self._loaded_model = model_id + return self._pipeline - def run(self, _): + def set_objective(self, X_train, y_train, task, **meta): + """Load pipeline then pre-compute embeddings for event_detection.""" + super().set_objective(X_train, y_train, task, **meta) + + if task == "event_detection": + self._n_classes = int(meta["n_classes"]) + self._d_model = _CHRONOS_D.get(self.model_size, 512) + self._Z_train = precompute_embeddings(self.model, X_train, batch_size=self.batch_size) + + def forecast_batch(self, inputs): + """Chronos-specific batch prediction. 
+ + Parameters + ---------- + inputs : list of torch.Tensor + Each tensor shape (C, T_cutoff) + + Returns + ------- + list of torch.Tensor + Each tensor shape (C, Q, H) + """ + with torch.no_grad(): + return self.model.predict(inputs, prediction_length=self.prediction_length) + + def build_adapter(self, task, model): + # TODO later: put that code in base_solver.py + # and make it rely on .forecast(), .embed() and .time_embed() only, once those are all properly coded + """Create task-specific adapter for Chronos.""" pred_len = self.meta.get("prediction_length", 1) - if self.task == "forecasting": - self._adapter = _ChronosForecaster(self._pipeline, pred_len) - elif self.task == "classification": - base_encoder = ChronosEncoder(self._pipeline, layer=self.layer) + if task == "forecasting": + self.prediction_length = pred_len + quantile_levels = tuple(float(q) for q in model.quantiles) + + # Create a simple adapter that calls self.forecast() + class _ForecastAdapter(BaseTSFMAdapter): + def __init__(self, solver, quantile_levels): + self.solver = solver + self.quantile_levels = quantile_levels + + def predict(self, x: ForecastInput) -> ForecastOutput: + return self.solver.forecast(x, self.solver.prediction_length, self.quantile_levels) + + return _ForecastAdapter(self, quantile_levels) + + elif task == "classification": + base_encoder = ChronosEncoder(model, layer=self.layer) encoder = Encoder(base_encoder, POOLERS[self.pooler]()) adapter = LinearProbeAdapter( encoder, @@ -303,7 +423,7 @@ def run(self, _): n_classes=self.meta.get("n_classes"), ) adapter.fit(self.X_train, self.y_train) - self._adapter = adapter + return adapter elif self.task == "anomaly_detection": # AD scores forecast residuals over an adaptive horizon. 
@@ -312,5 +432,26 @@ def run(self, _): _ChronosForecaster(self._pipeline, prediction_length=1), ) - def get_result(self): - return {"model": self._adapter} + # Create a forecaster adapter for residual-based anomaly detection + class _ForecasterForAD(BaseTSFMAdapter): + def __init__(self, solver, quantile_levels): + self.solver = solver + self.quantile_levels = quantile_levels + + def predict(self, x: ForecastInput) -> ForecastOutput: + return self.solver.forecast(x, 1, self.quantile_levels) + + forecaster = _ForecasterForAD(self, quantile_levels) + return ForecastResidualAdapter(forecaster, prediction_length=1) + + elif task == "event_detection": + device = "cuda" if torch.cuda.is_available() else "cpu" + head = fit_event_head( + self._Z_train, self.y_train, self._n_classes, self._d_model, + device, self.batch_size, self.num_epochs, self.lr, + self.weight_decay, self.warmup_epochs, self.num_dec_layers, + self.lambda_cls, self.num_queries, + ) + return ChronosEventAdapter(model, head, device, self._n_classes) + + raise ValueError(f"Unknown task: {task}") diff --git a/tests/benchmark_utils/test_metrics_map_iou.py b/tests/benchmark_utils/test_metrics_map_iou.py new file mode 100644 index 0000000..e7d5317 --- /dev/null +++ b/tests/benchmark_utils/test_metrics_map_iou.py @@ -0,0 +1,395 @@ +"""Unit tests for map_iou, f1_det and their helpers.""" + +import math + +import numpy as np +import pytest + +from benchmark_utils.metrics import ( + _ap_from_tp_fp, + _f1_from_class_counts, + _iou_1d, + _match_spans, + event_iou_f1, + event_span_iou, + f1_det, + map_iou, +) + + +def _evt(start, width, *class_cols): + """Single-row event array: [start, width, *class_cols].""" + return np.array([[start, width, *class_cols]], dtype=float) + + +def _no_evts(n_classes): + return np.zeros((0, 2 + n_classes)) + + +# --------------------------------------------------------------------------- +# _iou_1d +# --------------------------------------------------------------------------- + +def 
test_iou_identical(): + assert _iou_1d(0.0, 0.5, 0.0, 0.5) == pytest.approx(1.0) + +def test_iou_no_overlap(): + assert _iou_1d(0.0, 0.3, 0.4, 0.3) == pytest.approx(0.0) + +def test_iou_adjacent(): + # Segments touch at a single point — inter = 0 + assert _iou_1d(0.0, 0.5, 0.5, 0.5) == pytest.approx(0.0) + +def test_iou_partial(): + # [0, 1) and [0.5, 1.5) → inter=0.5, union=1.5 + assert _iou_1d(0.0, 1.0, 0.5, 1.0) == pytest.approx(0.5 / 1.5) + +def test_iou_contained(): + # [0, 1) contains [0.2, 0.6) → inter=0.4, union=1.0 + assert _iou_1d(0.0, 1.0, 0.2, 0.4) == pytest.approx(0.4 / 1.0) + +def test_iou_zero_width_returns_zero(): + assert _iou_1d(0.5, 0.0, 0.5, 0.0) == pytest.approx(0.0) + + +# --------------------------------------------------------------------------- +# _ap_from_tp_fp +# --------------------------------------------------------------------------- + +def test_ap_no_gt_returns_nan(): + assert math.isnan(_ap_from_tp_fp(np.array([1.0]), np.array([0.0]), 0)) + +def test_ap_all_tp(): + ap = _ap_from_tp_fp(np.array([1.0, 1.0]), np.array([0.0, 0.0]), n_gt=2) + assert ap == pytest.approx(1.0) + +def test_ap_all_fp(): + ap = _ap_from_tp_fp(np.array([0.0, 0.0]), np.array([1.0, 1.0]), n_gt=2) + assert ap == pytest.approx(0.0) + +def test_ap_tp_then_fp(): + # Rank 1: TP (precision=1, recall=0.5), Rank 2: FP (recall still 0.5) + # AP = Δrecall × precision = 0.5 × 1.0 = 0.5 + ap = _ap_from_tp_fp(np.array([1.0, 0.0]), np.array([0.0, 1.0]), n_gt=2) + assert ap == pytest.approx(0.5) + + +# --------------------------------------------------------------------------- +# map_iou +# --------------------------------------------------------------------------- + +def test_map_empty_y_true_returns_nan(): + assert math.isnan(map_iou([], [])) + +def test_map_no_predictions_returns_zero(): + assert map_iou([_evt(0.0, 0.5, 1)], [_no_evts(1)]) == pytest.approx(0.0) + +def test_map_no_gt_events_returns_nan(): + # Series has no GT events — AP is undefined + assert 
math.isnan(map_iou([_no_evts(1)], [_evt(0.0, 0.5, 0.9)])) + +def test_map_perfect_match(): + assert map_iou([_evt(0.1, 0.4, 1)], [_evt(0.1, 0.4, 1.0)]) == pytest.approx(1.0) + +def test_map_no_overlap(): + assert map_iou([_evt(0.0, 0.3, 1)], [_evt(0.7, 0.3, 1.0)]) == pytest.approx(0.0) + +def test_map_overlap_below_threshold_is_miss(): + # IoU ≈ 0.333 < 0.5 + assert map_iou([_evt(0.0, 1.0, 1)], [_evt(0.5, 1.0, 1.0)], iou_threshold=0.5) == pytest.approx(0.0) + +def test_map_overlap_above_custom_threshold_is_hit(): + # Same IoU ≈ 0.333 > 0.3 + assert map_iou([_evt(0.0, 1.0, 1)], [_evt(0.5, 1.0, 1.0)], iou_threshold=0.3) == pytest.approx(1.0) + +def test_map_duplicate_pred_only_first_matched(): + # Two identical predictions for one GT: first TP, second FP → AP=1.0 + preds = [np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 0.9]])] + assert map_iou([_evt(0.0, 0.5, 1)], preds) == pytest.approx(1.0) + +def test_map_predictions_ranked_by_score(): + # High-score pred misses (FP at rank 1), low-score pred hits (TP at rank 2) + # AP = Δrecall × precision at rank 2 = 1.0 × 0.5 = 0.5 + preds = [np.array([[0.8, 0.1, 0.9], [0.0, 0.5, 0.5]])] + assert map_iou([_evt(0.0, 0.5, 1)], preds) == pytest.approx(0.5) + +def test_map_two_classes_both_perfect(): + gt = np.array([[0.0, 0.3, 1, 0], [0.5, 0.3, 0, 1]]) + pred = np.array([[0.0, 0.3, 1.0, 0.0], [0.5, 0.3, 0.0, 1.0]]) + assert map_iou([gt], [pred]) == pytest.approx(1.0) + +def test_map_two_classes_one_missed(): + gt = np.array([[0.0, 0.3, 1, 0], [0.5, 0.3, 0, 1]]) + pred = np.array([ + [0.0, 0.3, 1.0, 0.0], # class 0: perfect hit + [0.0, 0.1, 0.0, 1.0], # class 1: no overlap with GT at [0.5, 0.3] + ]) + assert map_iou([gt], [pred]) == pytest.approx(0.5) + +def test_map_multi_series_both_matched(): + y_true = [_evt(0.0, 0.5, 1), _evt(0.2, 0.4, 1)] + y_pred = [_evt(0.0, 0.5, 1.0), _evt(0.2, 0.4, 1.0)] + assert map_iou(y_true, y_pred) == pytest.approx(1.0) + +def test_map_prediction_in_wrong_series_is_fp(): + # GT in series 0, pred only in 
series 1 — should not match + y_true = [_evt(0.0, 0.5, 1), _no_evts(1)] + y_pred = [_no_evts(1), _evt(0.0, 0.5, 1.0)] + assert map_iou(y_true, y_pred) == pytest.approx(0.0) + + +# --------------------------------------------------------------------------- +# Helpers for padded-array format used by event_span_iou / event_iou_f1 +# --------------------------------------------------------------------------- + +def _series(events, n_slots=5, n_classes=2): + """Padded (n_slots, 2+n_classes) array from list of (start, width, *cols).""" + arr = np.zeros((n_slots, 2 + n_classes)) + for i, ev in enumerate(events): + arr[i] = ev + return arr + + +# --------------------------------------------------------------------------- +# _match_spans +# --------------------------------------------------------------------------- + +def test_match_spans_empty_gt(): + gt = np.zeros((0, 4)) + pr = np.array([[0.0, 0.5, 0.9, 0.0]]) + assert _match_spans(gt, pr, 0.5, "greedy") == [] + +def test_match_spans_empty_pred(): + gt = np.array([[0.0, 0.5, 1.0, 0.0]]) + pr = np.zeros((0, 4)) + assert _match_spans(gt, pr, 0.5, "greedy") == [] + +def test_match_spans_perfect_match_greedy(): + gt = np.array([[0.0, 0.5, 1.0, 0.0]]) + pr = np.array([[0.0, 0.5, 0.9, 0.0]]) + assert _match_spans(gt, pr, 0.5, "greedy") == [(0, 0)] + +def test_match_spans_perfect_match_hungarian(): + gt = np.array([[0.0, 0.5, 1.0, 0.0]]) + pr = np.array([[0.0, 0.5, 0.9, 0.0]]) + assert _match_spans(gt, pr, 0.5, "hungarian") == [(0, 0)] + +def test_match_spans_below_threshold_no_match(): + # IoU([0,1), [0.5,1.5)) ≈ 0.33 < 0.5 + gt = np.array([[0.0, 1.0, 1.0, 0.0]]) + pr = np.array([[0.5, 1.0, 0.9, 0.0]]) + assert _match_spans(gt, pr, 0.5, "greedy") == [] + +def test_match_spans_greedy_vs_hungarian_differ(): + # pred0: IoU=1.0 with GT, score=0.6 + # pred1: IoU=0.8 with GT, score=0.9 + # Greedy: pred1 (higher score) goes first → matches GT → [(0, 1)] + # Hungarian: pred0 (higher IoU) is assigned → [(0, 0)] + gt = np.array([[0.0, 
0.5, 1.0, 0.0]]) + pr = np.array([ + [0.0, 0.5, 0.6, 0.0], # pred0: IoU=1.0, score=0.6 + [0.1, 0.4, 0.9, 0.0], # pred1: IoU=0.8, score=0.9 + ]) + assert _match_spans(gt, pr, 0.5, "greedy") == [(0, 1)] + assert _match_spans(gt, pr, 0.5, "hungarian") == [(0, 0)] + +def test_match_spans_duplicate_pred_only_one_matched(): + gt = np.array([[0.0, 0.5, 1.0, 0.0]]) + pr = np.array([[0.0, 0.5, 0.9, 0.0], [0.0, 0.5, 0.8, 0.0]]) + assert len(_match_spans(gt, pr, 0.5, "greedy")) == 1 + +def test_match_spans_two_gt_two_pred_both_matched(): + gt = np.array([[0.0, 0.3, 1, 0], [0.5, 0.3, 0, 1]]) + pr = np.array([[0.0, 0.3, 1.0, 0.0], [0.5, 0.3, 0.0, 1.0]]) + assert set(_match_spans(gt, pr, 0.5, "greedy")) == {(0, 0), (1, 1)} + +def test_match_spans_invalid_strategy_raises(): + gt = np.array([[0.0, 0.5, 1.0, 0.0]]) + pr = np.array([[0.0, 0.5, 0.9, 0.0]]) + with pytest.raises(ValueError): + _match_spans(gt, pr, 0.5, "invalid") + + +# --------------------------------------------------------------------------- +# _f1_from_class_counts +# --------------------------------------------------------------------------- + +def test_f1_counts_micro_perfect(): + assert _f1_from_class_counts(np.array([2, 1]), np.array([0, 0]), np.array([0, 0]), "micro") == pytest.approx(1.0) + +def test_f1_counts_micro_all_zeros(): + assert _f1_from_class_counts(np.array([0, 0]), np.array([0, 0]), np.array([0, 0]), "micro") == pytest.approx(0.0) + +def test_f1_counts_micro_mixed(): + # tp=1, fp=1, fn=1 → P=0.5, R=0.5 → F1=0.5 + assert _f1_from_class_counts(np.array([1]), np.array([1]), np.array([1]), "micro") == pytest.approx(0.5) + +def test_f1_counts_macro_mixed(): + # class 0: tp=1, fp=0, fn=0 → F1=1.0; class 1: tp=0, fp=1, fn=1 → F1=0 → mean=0.5 + assert _f1_from_class_counts(np.array([1, 0]), np.array([0, 1]), np.array([0, 1]), "macro") == pytest.approx(0.5) + +def test_f1_counts_micro_macro_differ(): + # class 0: tp=1,fp=0,fn=0 → F1=1.0 | class 1: tp=0,fp=0,fn=1 → F1=0 + # micro: tp_s=1,fp_s=0,fn_s=1 → 
P=1,R=0.5 → F1=2/3 + # macro: mean([1.0, 0.0]) = 0.5 + tp = np.array([1, 0]) + fp = np.array([0, 0]) + fn = np.array([0, 1]) + assert _f1_from_class_counts(tp, fp, fn, "micro") == pytest.approx(2 / 3) + assert _f1_from_class_counts(tp, fp, fn, "macro") == pytest.approx(0.5) + +def test_f1_counts_invalid_mode_raises(): + with pytest.raises(ValueError): + _f1_from_class_counts(np.array([1]), np.array([0]), np.array([0]), "invalid") + + +# --------------------------------------------------------------------------- +# event_span_iou +# --------------------------------------------------------------------------- + +def test_event_span_iou_both_empty(): + s = _series([]) + assert event_span_iou([s], [s]) == pytest.approx(1.0) + +def test_event_span_iou_pred_empty(): + gt = _series([[0.0, 0.5, 1, 0]]) + pr = _series([]) + assert event_span_iou([gt], [pr]) == pytest.approx(0.0) + +def test_event_span_iou_gt_empty(): + gt = _series([]) + pr = _series([[0.0, 0.5, 0.9, 0.0]]) + assert event_span_iou([gt], [pr]) == pytest.approx(0.0) + +def test_event_span_iou_perfect_match(): + gt = _series([[0.0, 0.5, 1, 0]]) + pr = _series([[0.0, 0.5, 0.9, 0.0]]) + assert event_span_iou([gt], [pr]) == pytest.approx(1.0) + +def test_event_span_iou_no_overlap(): + gt = _series([[0.0, 0.3, 1, 0]]) + pr = _series([[0.7, 0.3, 0.9, 0.0]]) + assert event_span_iou([gt], [pr]) == pytest.approx(0.0) + +def test_event_span_iou_pred_below_score_threshold_ignored(): + # max class score 0.3 < 0.5 → pred filtered out → G=1, P=0 → F1=0 + gt = _series([[0.0, 0.5, 1, 0]]) + pr = _series([[0.0, 0.5, 0.3, 0.0]]) + assert event_span_iou([gt], [pr]) == pytest.approx(0.0) + +def test_event_span_iou_multi_series_averaged(): + # series 0: perfect → F1=1.0; series 1: no pred → F1=0.0; mean=0.5 + gt0, pr0 = _series([[0.0, 0.5, 1, 0]]), _series([[0.0, 0.5, 0.9, 0.0]]) + gt1, pr1 = _series([[0.2, 0.4, 1, 0]]), _series([]) + assert event_span_iou([gt0, gt1], [pr0, pr1]) == pytest.approx(0.5) + + +# 
--------------------------------------------------------------------------- +# event_iou_f1 +# --------------------------------------------------------------------------- + +def test_event_iou_f1_no_arrays_returns_nan(): + assert math.isnan(event_iou_f1([], [])) + +def test_event_iou_f1_perfect_match(): + gt = _series([[0.0, 0.5, 1, 0]]) + pr = _series([[0.0, 0.5, 0.9, 0.1]]) + assert event_iou_f1([gt], [pr]) == pytest.approx(1.0) + +def test_event_iou_f1_unmatched_gt_is_fn(): + gt = _series([[0.0, 0.5, 1, 0]]) + pr = _series([]) + assert event_iou_f1([gt], [pr]) == pytest.approx(0.0) + +def test_event_iou_f1_unmatched_pred_is_fp(): + gt = _series([]) + pr = _series([[0.0, 0.5, 0.9, 0.0]]) + assert event_iou_f1([gt], [pr]) == pytest.approx(0.0) + +def test_event_iou_f1_wrong_class_prediction(): + # Span matched by IoU but class assignment is swapped → class errors + # tp=[0,0], fp=[1,0], fn=[0,1] + # micro: tp_s=0, fp_s=1, fn_s=1 → F1=0 + gt = _series([[0.0, 0.5, 1, 0]]) # class 0 active + pr = _series([[0.0, 0.5, 0.4, 0.9]]) # predicts class 1 (max=0.9 > 0.5) + assert event_iou_f1([gt], [pr]) == pytest.approx(0.0) + +def test_event_iou_f1_micro_vs_macro_differ(): + # GT0=class0, GT1=class1; Pred0 correct, Pred1 wrong class + # After matching: (GT0,P0) tp=[1,0]; (GT1,P1) gt_cls=[0,1] pr_cls=[1,0] + # → tp+=[0,0], fp+=[1,0], fn+=[0,1] + # total: tp=[1,0], fp=[1,0], fn=[0,1] + # micro: tp_s=1,fp_s=1,fn_s=1 → P=0.5,R=0.5 → F1=0.5 + # macro: class0 P=0.5,R=1→F1=2/3; class1 P=0,R=0→F1=0 → mean=1/3 + gt = _series([[0.0, 0.3, 1, 0], [0.5, 0.3, 0, 1]]) + pr = _series([[0.0, 0.3, 0.9, 0.1], [0.5, 0.3, 0.6, 0.4]]) + assert event_iou_f1([gt], [pr], mode="micro") == pytest.approx(0.5) + assert event_iou_f1([gt], [pr], mode="macro") == pytest.approx(1 / 3) + +def test_event_iou_f1_hungarian_strategy(): + gt = _series([[0.0, 0.5, 1, 0]]) + pr = _series([[0.0, 0.5, 0.9, 0.1]]) + assert event_iou_f1([gt], [pr], matching_strategy="hungarian") == pytest.approx(1.0) + +def 
test_event_iou_f1_multi_series(): + gt0 = _series([[0.0, 0.5, 1, 0]]) + pr0 = _series([[0.0, 0.5, 0.9, 0.0]]) + gt1 = _series([[0.2, 0.4, 0, 1]]) + pr1 = _series([[0.2, 0.4, 0.1, 0.8]]) + assert event_iou_f1([gt0, gt1], [pr0, pr1]) == pytest.approx(1.0) + + +# --------------------------------------------------------------------------- +# f1_det +# --------------------------------------------------------------------------- + +def test_f1det_empty_y_true_returns_nan(): + assert math.isnan(f1_det([], [])) + +def test_f1det_no_gt_events_returns_nan(): + assert math.isnan(f1_det([_no_evts(1)], [_evt(0.0, 0.5, 0.9)])) + +def test_f1det_perfect_match_fixed_threshold(): + assert f1_det([_evt(0.0, 0.5, 1)], [_evt(0.0, 0.5, 0.9)], score_threshold=0.5) == pytest.approx(1.0) + +def test_f1det_no_overlap_fixed_threshold(): + # TP=0, FP=1, FN=1 → F1=0 + assert f1_det([_evt(0.0, 0.3, 1)], [_evt(0.7, 0.3, 0.9)], score_threshold=0.5) == pytest.approx(0.0) + +def test_f1det_prediction_below_threshold_becomes_fn(): + # Pred score 0.3 < threshold 0.5 → filtered out → TP=0, FP=0, FN=1 → F1=0 + assert f1_det([_evt(0.0, 0.5, 1)], [_evt(0.0, 0.5, 0.3)], score_threshold=0.5) == pytest.approx(0.0) + +def test_f1det_oracle_finds_best_threshold(): + # Low-score pred (0.2) would be filtered by fixed threshold=0.5 but oracle should find it + assert f1_det([_evt(0.0, 0.5, 1)], [_evt(0.0, 0.5, 0.2)], score_threshold=None) == pytest.approx(1.0) + +def test_f1det_duplicate_pred_second_is_fp(): + # TP=1, FP=1, FN=0 → F1 = 2/(2+1+0) = 2/3 + preds = [np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 0.9]])] + assert f1_det([_evt(0.0, 0.5, 1)], preds, score_threshold=0.5) == pytest.approx(2 / 3) + +def test_f1det_two_classes_both_perfect(): + gt = np.array([[0.0, 0.3, 1, 0], [0.5, 0.3, 0, 1]]) + pred = np.array([[0.0, 0.3, 1.0, 0.0], [0.5, 0.3, 0.0, 1.0]]) + assert f1_det([gt], [pred], score_threshold=0.5) == pytest.approx(1.0) + +def test_f1det_two_classes_one_missed(): + gt = np.array([[0.0, 0.3, 1, 0], [0.5, 
0.3, 0, 1]]) + pred = np.array([ + [0.0, 0.3, 1.0, 0.0], # class 0: perfect hit + [0.0, 0.1, 0.0, 0.9], # class 1: no overlap with GT at [0.5, 0.3] + ]) + # class 0: F1=1.0, class 1: TP=0 FP=1 FN=1 → F1=0 → mean=0.5 + assert f1_det([gt], [pred], score_threshold=0.5) == pytest.approx(0.5) + +def test_f1det_multi_series_both_matched(): + y_true = [_evt(0.0, 0.5, 1), _evt(0.2, 0.4, 1)] + y_pred = [_evt(0.0, 0.5, 0.9), _evt(0.2, 0.4, 0.8)] + assert f1_det(y_true, y_pred, score_threshold=0.5) == pytest.approx(1.0) + +def test_f1det_prediction_in_wrong_series_is_fp(): + y_true = [_evt(0.0, 0.5, 1), _no_evts(1)] + y_pred = [_no_evts(1), _evt(0.0, 0.5, 0.9)] + # TP=0, FP=1, FN=1 → F1=0 + assert f1_det(y_true, y_pred, score_threshold=0.5) == pytest.approx(0.0)