diff --git a/.gitignore b/.gitignore index 091a223..a90798c 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__ __cache__ *.egg-info +*.pth .coverage **/outputs joblib/ @@ -19,3 +20,4 @@ coverage.xml # Data directories data/ exploratory/ +src/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a31f116..dca8836 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,13 +8,13 @@ repos: - id: check-json - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 + rev: 7.3.0 hooks: - id: flake8 args: [--max-line-length=79] # Customize flake8 options here - - repo: https://github.com/pre-commit/mirrors-autopep8 - rev: v1.6.0 + - repo: https://github.com/hhatto/autopep8 + rev: v2.3.2 hooks: - id: autopep8 - args: [--max-line-length=79, --in-place] \ No newline at end of file + args: [--max-line-length=79, --in-place] diff --git a/benchmark_utils/download.py b/benchmark_utils/download.py new file mode 100644 index 0000000..becb2c9 --- /dev/null +++ b/benchmark_utils/download.py @@ -0,0 +1,77 @@ +"""Shared download helper for the TSB-UAD public dataset bundle. +""" +from pathlib import Path + +from benchopt import config + + +_BUNDLE_URL = "https://www.thedatum.org/datasets/TSB-UAD-Public.zip" +_BUNDLE_SHA256 = ( + "ff4aa83a5a111835d410d962152e8dbebcda1039b778bae45b6b9c3f46dd49a1" +) +_BUNDLE_FILENAME = "TSB-UAD-Public.zip" +_BUNDLE_ROOT = "TSB-UAD-Public" + +# Map benchmark dataset name -> subdirectory inside the TSB-UAD bundle. +_SUBDIR = { + "DAPHNET": "Daphnet", + "DODGERS": "Dodgers", + "ECG": "ECG", + "GENESIS": "Genesis", + "GHL": "GHL", + "IOPS": "IOPS", + "KDD21": "KDD21", + "MGAB": "MGAB", + "MITDB": "MITDB", + "NAB": "NAB", + "OCCUPANCY": "Occupancy", + "OPPORTUNITY": "OPPORTUNITY", + "SENSORSCOPE": "SensorScope", + "SMD": "SMD", + "SVDB": "SVDB", + "YAHOO": "YAHOO", +} + + +def fetch_tsb_uad(name: str) -> Path: + """Return the local directory holding TSB-UAD's ``.out`` files for *name*. + + The bundle is downloaded once into + ``benchopt.config.get_data_path("TSB-UAD-Public")`` and extracted; + subsequent calls are cache hits. + """ + if name not in _SUBDIR: + raise KeyError( + f"{name!r} is not a TSB-UAD dataset name. " + f"Known names: {sorted(_SUBDIR)}" + ) + + import pooch # local import: only required when downloading + + try: + import tqdm # noqa: F401 + progressbar = True + except ImportError: + progressbar = False + + cache_root = Path(config.get_data_path(key=_BUNDLE_ROOT)) + cache_root.mkdir(parents=True, exist_ok=True) + + registry = pooch.create( + path=cache_root, + base_url="https://www.thedatum.org/datasets/", + registry={_BUNDLE_FILENAME: f"sha256:{_BUNDLE_SHA256}"}, + urls={_BUNDLE_FILENAME: _BUNDLE_URL}, + ) + registry.fetch( + _BUNDLE_FILENAME, + processor=pooch.Unzip(extract_dir="."), + progressbar=progressbar, + ) + + subdir = cache_root / _BUNDLE_ROOT / _SUBDIR[name] + if not subdir.exists(): + raise FileNotFoundError( + f"Expected {subdir} after extracting the TSB-UAD bundle." + ) + return subdir diff --git a/datasets/dodgers.py b/datasets/dodgers.py index f3c6879..8d3b7a2 100644 --- a/datasets/dodgers.py +++ b/datasets/dodgers.py @@ -1,10 +1,10 @@ -from benchopt import BaseDataset, config +from benchopt import BaseDataset from pathlib import Path import numpy as np import pandas as pd -PATH = config.get_data_path("DODGERS") +from benchmark_utils.download import fetch_tsb_uad def load_data(db_path, record_ids=None, verbose=False): @@ -90,6 +90,8 @@ def load_data(db_path, record_ids=None, verbose=False): class Dataset(BaseDataset): name = "DODGERS" + requirements = ["pip:pooch"] + parameters = { # "recordings_id": [["101"]], "recordings_id": [None], @@ -99,11 +101,13 @@ class Dataset(BaseDataset): def get_data(self): """Load the DODGERS dataset.""" + path = fetch_tsb_uad("DODGERS") + # X shape (n_recordings, n_samples) # y shape (n_recordings, n_samples) if self.recordings_id in (["all"], "all"): self.recordings_id = None - X, y_true = load_data(PATH, self.recordings_id) + X, y_true = load_data(path, self.recordings_id) X_test = X.copy() y_test = y_true.copy() diff --git a/datasets/mitdb.py b/datasets/mitdb.py index 7f811d0..5cf8668 100644 --- a/datasets/mitdb.py +++ b/datasets/mitdb.py @@ -1,10 +1,10 @@ -from benchopt import BaseDataset, config +from benchopt import BaseDataset from pathlib import Path import numpy as np import pandas as pd -PATH = config.get_data_path("MITDB") +from benchmark_utils.download import fetch_tsb_uad def load_mitdb_data(db_path, record_ids=None, verbose=False): @@ -102,6 +102,8 @@ def load_mitdb_data(db_path, record_ids=None, verbose=False): class Dataset(BaseDataset): name = "MITDB" + requirements = ["pip:pooch"] + parameters = { "recordings_id": [["100", "201", "109", "105", "111", "221"]], "debug": [False], @@ -110,11 +112,13 @@ class Dataset(BaseDataset): def get_data(self): """Load the MITDB dataset.""" + path = fetch_tsb_uad("MITDB") + # X shape (n_recordings, n_samples) # y shape (n_recordings, n_samples) if self.recordings_id in (["all"], "all"): self.recordings_id = None - X, y_true = load_mitdb_data(PATH, self.recordings_id) + X, y_true = load_mitdb_data(path, self.recordings_id) X_test = X.copy() y_test = y_true.copy() diff --git a/datasets/nab.py b/datasets/nab.py index 20a0960..88b1d0f 100644 --- a/datasets/nab.py +++ b/datasets/nab.py @@ -1,10 +1,10 @@ -from benchopt import BaseDataset, config +from benchopt import BaseDataset from pathlib import Path import numpy as np import pandas as pd -PATH = config.get_data_path("NAB") +from benchmark_utils.download import fetch_tsb_uad def load_data(db_path, record_ids=None, verbose=False): @@ -88,6 +88,8 @@ def load_data(db_path, record_ids=None, verbose=False): class Dataset(BaseDataset): name = "NAB" + requirements = ["pip:pooch"] + parameters = { "recordings_id": [["art0"], ["art1"], ["CloudWatch"]], "debug": [False], @@ -96,9 +98,11 @@ class Dataset(BaseDataset): def get_data(self): """Load the NAB dataset.""" + path = fetch_tsb_uad("NAB") + # X shape (n_recordings, n_samples) # y shape (n_recordings, n_samples) - X, y_true = load_data(PATH, self.recordings_id) + X, y_true = load_data(path, self.recordings_id) X_test = X.copy() y_test = y_true.copy() diff --git a/test_config.py b/test_config.py index 3606a74..cd15317 100644 --- a/test_config.py +++ b/test_config.py @@ -56,9 +56,9 @@ def check_test_solver_run(benchmark, solver_class): def check_test_dataset_get_data(benchmark, dataset_class): if dataset_class.name.lower() in [ - "daphnet", "dodgers", "ecg", "genesis", "ghl", - "iops", "kdd21", "mgab", "mitdb", "nab", + "daphnet", "ecg", "genesis", "ghl", + "iops", "kdd21", "mgab", "occupancy", "opportunity", "sensorscope", "smd", - "svdb", "yahoo" + "svdb", "yahoo", "nab", "mitdb", "dodgers", ]: pytest.xfail(f"{dataset_class.name} dataset is not downloaded.")