From 84199b99ce0d01476f06be943863e11c6f8e3e6a Mon Sep 17 00:00:00 2001 From: JesusGF1 Date: Sun, 26 Apr 2026 17:45:34 -0700 Subject: [PATCH 1/8] Add HIPPIE optional dependency for cell-type classification Integrates the pretrained HIPPIE multimodal model (Jesusgf23/hippie on HuggingFace) as an optional spikelab[hippie] extra. Neurons are encoded via waveform + ISI + autocorrelogram into a 30-D latent space, then visualised with UMAP and clustered with HDBSCAN. Changes: - spikedata/hippie_adapter.py: feature extraction from SpikeData and full classify_neurons() pipeline - mcp_server/tools/analysis.py: classify_neurons_hippie() async tool - mcp_server/server.py: tool schema and _TOOL_DISPATCH entry - pyproject.toml: hippie optional extra (torch, hf-hub, umap-learn, hdbscan) - tests/test_hippie_adapter.py: 13 unit tests, skipped without hippie - docs/guides/hippie_classification.rst: full pipeline guide covering Kilosort+raw-bin (Path A), NWB (Path B), pre-computed waveforms (Path C), MCP agent prompts, and known MCP per-unit waveform limitation - docs/guides/hippie_changes.md: codebase change schematic Co-Authored-By: Claude Sonnet 4.6 --- docs/source/guides/hippie_changes.md | 375 ++++++++++++++++++ docs/source/guides/hippie_classification.rst | 381 +++++++++++++++++++ docs/source/guides/index.rst | 1 + pyproject.toml | 8 + src/spikelab/mcp_server/server.py | 74 ++++ src/spikelab/mcp_server/tools/analysis.py | 100 +++++ src/spikelab/spikedata/hippie_adapter.py | 221 +++++++++++ tests/test_hippie_adapter.py | 185 +++++++++ 8 files changed, 1345 insertions(+) create mode 100644 docs/source/guides/hippie_changes.md create mode 100644 docs/source/guides/hippie_classification.rst create mode 100644 src/spikelab/spikedata/hippie_adapter.py create mode 100644 tests/test_hippie_adapter.py diff --git a/docs/source/guides/hippie_changes.md b/docs/source/guides/hippie_changes.md new file mode 100644 index 00000000..fac102c4 --- /dev/null +++ b/docs/source/guides/hippie_changes.md @@ -0,0 +1,375 @@ +# HIPPIE Integration — Codebase Changes + +## Overview + +HIPPIE is added as an **optional dependency** of SpikeLab. +Users who want cell-type classification install it with: + +```bash +pip install "spikelab[hippie]" +``` + +Nothing in the base SpikeLab install changes — all new code is either in new files or behind lazy imports that only run when HIPPIE is actually present. + +--- + +## Repository map + +``` +HIPPIE/ +└── hippie/ + ├── __init__.py ← MODIFIED (+2 lines) + ├── inference.py ← NEW + └── checkpoint.py ← NEW + +SpikeLab/ +├── pyproject.toml ← MODIFIED (+7 lines) +├── src/spikelab/ +│ ├── spikedata/ +│ │ └── hippie_adapter.py ← NEW +│ └── mcp_server/ +│ ├── tools/ +│ │ └── analysis.py ← MODIFIED (+90 lines at EOF) +│ └── server.py ← MODIFIED (+60 lines: schema + dispatch) +├── tests/ +│ └── test_hippie_adapter.py ← NEW +└── docs/source/guides/ + ├── index.rst ← MODIFIED (+1 line) + └── hippie_classification.rst ← NEW +``` + +--- + +## File-by-file changes + +### `HIPPIE/hippie/__init__.py` — modified + +**What changed:** Added one export line. + +```python +# Before (4 lines) +from .multimodal_model import MultiModalCVAE, ... +from .dataloading import ... +from .augmentations import ... +from .backbones import ... + +# After (+1 line) +from .inference import HIPPIEClassifier, TECHNOLOGY_IDS +``` + +**Why:** Allows `from hippie import HIPPIEClassifier` — the clean public API. + +--- + +### `HIPPIE/hippie/checkpoint.py` — new file + +**Purpose:** Load a pretrained checkpoint and return a ready `MultiModalCVAE`. + +``` +infer_model_dims(state_dict) + → reads source_embed.weight / class_embed.weight shapes + → returns (num_sources, num_classes) + +build_model(ckpt_path) + → torch.load(..., weights_only=False) + → strips Lightning "model." prefix from state_dict keys + → calls infer_model_dims to auto-detect architecture + → instantiates MultiModalCVAE with hardcoded inference config: + modalities = {"wave": 50, "isi": 100, "acg": 100} + z_dim = 30 + config = ExperimentConfigs.class_decoder_source_bn_aug_reg() + backbone_base_width = 64 + → model.load_state_dict(sd, strict=False) + → returns model in eval() mode on CPU +``` + +**Key detail:** `strict=False` on `load_state_dict` — the checkpoint may contain +decoder weights that are not needed for inference; this suppresses the error. + +--- + +### `HIPPIE/hippie/inference.py` — new file + +**Purpose:** High-level inference API used by both direct callers and the SpikeLab adapter. + +``` +TECHNOLOGY_IDS: dict + {"neuropixels": 0, "silicon_probe": 1, "juxtacellular": 2, "tetrodes": 3} + +class HIPPIEClassifier: + + from_pretrained(repo_id, filename, device, cache_dir) + → hf_hub_download(repo_id, filename) + → build_model(ckpt_path) + → returns HIPPIEClassifier + + from_checkpoint(checkpoint_path, device) + → build_model(str(path)) + → returns HIPPIEClassifier + + get_embeddings(wave, isi, acg, tech_id, batch_size) → np.ndarray (N, 30) + → accepts tech_id as int or name string + → batches input; calls model.encode({wave, isi, acg}, source_labels=tech_id) + → returns z_mean concatenated across batches + + umap_reduce(embeddings, n_components, n_neighbors, min_dist, metric, random_state) + → static method; lazy-imports umap.UMAP + → returns (N, n_components) float32 coords + + hdbscan_cluster(embeddings, min_cluster_size, min_samples, metric) + → static method; lazy-imports hdbscan.HDBSCAN + → returns (N,) int32 labels; -1 = noise +``` + +**Input contract for `get_embeddings`:** + +| Modality | Shape | Normalization | +|----------|---------|------------------------------------------| +| `wave` | (N, 50) | min-max → [-1, 1] | +| `isi` | (N, 100)| log(x+1), then min-max → [-1, 1] | +| `acg` | (N, 100)| min-max → [-1, 1] | + +--- + +### `SpikeLab/pyproject.toml` — modified + +**What changed:** Added the `hippie` optional extra. + +```toml +[project.optional-dependencies] +# ... existing extras unchanged ... +hippie = [ + "hippie @ git+https://github.com/braingeneers/HIPPIE.git", + "huggingface-hub>=0.20", + "torch>=2.0", + "umap-learn>=0.5.0", + "hdbscan>=0.8", +] +``` + +**Note:** PyTorch with CUDA must still be installed separately if GPU inference is desired. + +--- + +### `SpikeLab/src/spikelab/spikedata/hippie_adapter.py` — new file + +**Purpose:** Bridge between `SpikeData` and the HIPPIE encoder. +Not exported from `spikedata/__init__.py` — imported explicitly to avoid import errors for users without HIPPIE. + +``` +# Guards +_require_hippie() + → tries `import hippie`; raises ImportError with install hint if missing + +# Preprocessing helpers (pure numpy/torch, no HIPPIE dependency) +_preprocess_waveform(wave, target=50) → (50,) float32 + → F.interpolate to target length, min-max → [-1, 1] + +_isi_histogram(spike_times, n_bins=100) → (100,) float32 + → np.diff(sorted spikes) * 1000 → ms + → np.histogram with log-spaced bins [1 ms, 5000 ms] + → log1p transform, then min-max → [-1, 1] + → silent neurons (< 2 spikes) return flat -1 array + +_autocorrelogram(spike_times, max_lag_ms=100, n_bins=100) → (100,) float32 + → forward-looking searchsorted loop (O(N) per lag window) + → normalised to sum=1, then min-max → [-1, 1] + → empty trains return zeros + +# Public API +extract_features(sd, isi_bins, acg_bins, acg_max_lag_ms) → dict + → reads avg_waveform from neuron_attributes (raises ValueError if missing) + → stacks waveforms, ISI histograms, ACGs into (N, bins) arrays + → returns {"wave": (N,50), "isi": (N,100), "acg": (N,100)} + +classify_neurons(sd, repo_id, tech_id, device, run_umap, run_hdbscan, + umap_kwargs, hdbscan_kwargs, batch_size, cache_dir) → dict + → calls _require_hippie() then imports HIPPIEClassifier inside function + → calls extract_features(sd) + → HIPPIEClassifier.from_pretrained(repo_id, device, cache_dir) + → clf.get_embeddings(wave, isi, acg, tech_id, batch_size) + → optionally: clf.umap_reduce(embeddings, **umap_kwargs) + → optionally: clf.hdbscan_cluster(umap_coords or embeddings, **hdbscan_kwargs) + → returns {"embeddings", "umap_coords"?, "cluster_labels"?} +``` + +--- + +### `SpikeLab/src/spikelab/mcp_server/tools/analysis.py` — modified + +**What changed:** ~90 lines appended at end of file. + +``` +async def classify_neurons_hippie( + workspace_id, namespace, + tech_id=0, run_umap=True, run_hdbscan=True, + min_cluster_size=5, umap_n_neighbors=30, umap_min_dist=0.1, + device="cpu", cache_dir=None +) → dict + + Calls: + _get_workspace(workspace_id) + _get_spikedata(ws, namespace) + classify_neurons(sd, ...) ← imported lazily inside function + sd.set_neuron_attribute(...) ← writes hippie_embedding, + hippie_umap_x/y, hippie_cluster + ws.store(namespace, "spikedata", sd) + + Returns JSON summary: + n_neurons, embedding_dim, + umap_computed, hdbscan_computed, + n_clusters, n_noise_neurons, + neuron_attributes_added +``` + +**Import note:** `from ....spikedata.hippie_adapter import classify_neurons` is +inside the function body — the MCP server starts fine without HIPPIE installed +and only fails at call time with a clear error message. + +--- + +### `SpikeLab/src/spikelab/mcp_server/server.py` — modified + +**Two edits:** + +1. **Tool schema** — inserted before the "Workspace management tools" section: + +```python +types.Tool( + name="classify_neurons_hippie", + description="Classify neurons using the pretrained HIPPIE multimodal model ...", + inputSchema={ + "required": ["workspace_id", "namespace"], + "properties": { + workspace_id, namespace, + tech_id (int, default 0), + run_umap (bool, default True), + run_hdbscan (bool, default True), + min_cluster_size (int, default 5), + umap_n_neighbors (int, default 30), + umap_min_dist (float, default 0.1), + device (str, default "cpu"), + cache_dir (str, optional), + } + } +) +``` + +2. **Dispatch entry** — added to `_TOOL_DISPATCH`: + +```python +"classify_neurons_hippie": analysis.classify_neurons_hippie, +``` + +--- + +### `SpikeLab/tests/test_hippie_adapter.py` — new file + +**13 tests across 5 classes**, all skipped automatically if `hippie` is not installed (`pytest.importorskip`). + +``` +TestPreprocessWaveform (3 tests) + ✓ output shape is (50,) + ✓ values in [-1, 1] + ✓ flat waveform (all zeros) does not crash + +TestISIHistogram (3 tests) + ✓ output shape is (100,) + ✓ values in [-1, 1] + ✓ silent neuron (1 spike) does not crash + +TestAutocorrelogram (3 tests) + ✓ output shape is (100,) + ✓ values in [-1, 1] + ✓ empty spike train returns zeros + +TestExtractFeatures (3 tests) + ✓ shapes (N, 50) / (N, 100) / (N, 100) + ✓ dtype is float32 + ✓ missing avg_waveform raises ValueError + +TestClassifyNeurons (4 tests — fully mocked, no download) + @patch("hippie.inference.HIPPIEClassifier") ← correct patch target + ✓ returns embeddings + umap_coords + cluster_labels by default + ✓ run_umap=False, run_hdbscan=False omits those keys + ✓ embedding shape is (N, 30) + ✓ tech_id string is forwarded unchanged to get_embeddings +``` + +--- + +### `SpikeLab/docs/source/guides/hippie_classification.rst` — new file + +Full Sphinx guide covering: +- Installation +- Quick start +- Return value table +- Technology ID table +- Advanced options (custom UMAP/HDBSCAN params, embeddings only, direct HIPPIE API) +- MCP server usage + example agent prompts +- How it works (feature extraction → encoding → UMAP → HDBSCAN) +- Checkpoint info + +### `SpikeLab/docs/source/guides/index.rst` — modified + +Added `hippie_classification` to the toctree (+1 line). + +--- + +## Data flow + +``` +SpikeData + │ + │ avg_waveform (neuron_attributes) + │ spike trains (sd.train) + ▼ +hippie_adapter.extract_features() + │ + ├── _preprocess_waveform() → (N, 50) min-max [-1,1] + ├── _isi_histogram() → (N, 100) log1p + min-max [-1,1] + └── _autocorrelogram() → (N, 100) normalized + min-max [-1,1] + │ + ▼ +HIPPIEClassifier.get_embeddings() + │ + │ model.encode({wave, isi, acg}, source_labels=tech_id) + │ ↳ ResNet18Enc × 3 → fusion_encoder → z_mean (30-D) + │ + ▼ + (N, 30) embeddings + │ + ├──[run_umap=True]──► umap_reduce() → (N, 2) coords + │ cosine metric, n_neighbors=30 + │ + └──[run_hdbscan=True]► hdbscan_cluster() on umap_coords + → (N,) labels (-1 = noise) + │ + ▼ +classify_neurons() returns + {"embeddings": (N,30), "umap_coords": (N,2), "cluster_labels": (N,)} + │ + ▼ [via MCP tool classify_neurons_hippie] +neuron_attributes: + hippie_embedding (N, 30) + hippie_umap_x (N,) + hippie_umap_y (N,) + hippie_cluster (N,) -1 = noise +``` + +--- + +## Bug fixed during audit + +| File | Issue | Fix | +|------|-------|-----| +| `tests/test_hippie_adapter.py` | `@patch("spikelab.spikedata.hippie_adapter.HIPPIEClassifier")` — wrong target; `HIPPIEClassifier` is imported *inside* `classify_neurons` via `from hippie.inference import HIPPIEClassifier`, so the mock must be placed on the source module | Changed to `@patch("hippie.inference.HIPPIEClassifier")` | + +--- + +## What was NOT changed + +- `spikedata/__init__.py` — `hippie_adapter` is intentionally **not** re-exported here; importing it would crash base installs without HIPPIE +- Any existing analysis, data loader, or exporter tool +- HIPPIE training code (`multimodal_model.py`, `dataloading.py`, `augmentations.py`, etc.) +- The HIPPIE `pyproject.toml` diff --git a/docs/source/guides/hippie_classification.rst b/docs/source/guides/hippie_classification.rst new file mode 100644 index 00000000..e9139180 --- /dev/null +++ b/docs/source/guides/hippie_classification.rst @@ -0,0 +1,381 @@ +.. _hippie_classification: + +Cell-Type Classification with HIPPIE +===================================== + +SpikeLab has optional integration with `HIPPIE`_, a pretrained multimodal +generative model for neuron classification. HIPPIE encodes each neuron's +**waveform**, **interspike-interval distribution**, and **autocorrelogram** +into a shared 30-D latent space, then uses UMAP + HDBSCAN for unsupervised +cell-type discovery. + +.. _HIPPIE: https://huggingface.co/Jesusgf23/hippie + +Installation +------------ + +HIPPIE is an optional dependency — install it alongside SpikeLab:: + + pip install "spikelab[hippie]" + +This pulls in PyTorch, HuggingFace Hub, umap-learn, and hdbscan in addition +to the HIPPIE package itself. PyTorch with CUDA must be installed separately +if GPU inference is desired. + +.. note:: + + Nothing in the base SpikeLab install is affected. The HIPPIE adapter is + never imported unless you explicitly call it. + +Data requirements +----------------- + +HIPPIE requires three features per neuron: + +* **Average waveform** — stored as ``avg_waveform`` in ``neuron_attributes`` +* **Spike trains** — always present in a ``SpikeData`` object +* **Recording technology** — passed as ``tech_id`` at call time + +The waveform is the only thing that may need preparation. The three +pipelines below cover the most common starting points. + +.. _hippie_pipeline_a: + +Pipeline A — Kilosort output + raw ``.bin`` file +------------------------------------------------- + +This is the typical Neuropixels + Kilosort4 workflow. Kilosort gives you +spike times only; waveforms are extracted from the raw voltage trace +afterward. + +.. note:: + + Attaching raw data to a ``SpikeData`` object is currently a Python-only + step — there is no MCP tool for it. Use this path when scripting directly. + +.. code-block:: python + + import numpy as np + from spikelab.data_loaders import load_spikedata_from_kilosort + from spikelab.spikedata.hippie_adapter import classify_neurons + + # 1. Load spike times from Kilosort output directory + sd = load_spikedata_from_kilosort( + folder_path="/path/to/kilosort_output/", + fs_Hz=30000, # Neuropixels default + cluster_info_tsv="cluster_info.tsv", + include_noise=False, + ) + + # 2. Attach the raw voltage recording + # Shape must be (n_channels, n_samples). + # Use np.memmap for large files to avoid loading everything into RAM. + raw = np.memmap( + "/path/to/recording.ap.bin", + dtype="int16", + mode="r", + shape=(385, n_samples), # adjust n_channels and n_samples + ) + sd.raw_data = raw.astype(np.float32) + sd.raw_time = 30.0 # sampling rate in kHz (30 000 Hz) + + # 3. Extract average waveforms for all units in one call. + # store=True writes avg_waveform into neuron_attributes automatically. + sd.get_waveform_traces( + unit=None, # None = all units + ms_before=1.0, + ms_after=2.0, + store=True, + ) + + # 4. Run HIPPIE: embed → UMAP → HDBSCAN + result = classify_neurons( + sd, + tech_id="neuropixels", # or tech_id=0 + run_umap=True, + run_hdbscan=True, + hdbscan_kwargs={"min_cluster_size": 5}, + ) + + # 5. Store results back into neuron_attributes + sd.set_neuron_attribute("hippie_cluster", result["cluster_labels"]) + sd.set_neuron_attribute("hippie_umap_x", result["umap_coords"][:, 0]) + sd.set_neuron_attribute("hippie_umap_y", result["umap_coords"][:, 1]) + sd.set_neuron_attribute("hippie_embedding", result["embeddings"]) + + n_clusters = (result["cluster_labels"] >= 0).sum() + print(f"{sd.N} neurons → {n_clusters} clustered, " + f"{(result['cluster_labels'] < 0).sum()} noise") + +What ``get_waveform_traces`` does in step 3 +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For every unit it finds the peak channel (from ``neuron_to_channel_map``), +extracts a short voltage snippet around each spike, averages them, and +stores the average as ``neuron_attributes[i]["avg_waveform"]``. The adapter +then reads those stored values — no raw data is needed after this point. + + +.. _hippie_pipeline_b: + +Pipeline B — NWB file with raw traces +-------------------------------------- + +NWB files produced by SpikeInterface, the Allen Brain Atlas pipeline, or +similar tools often embed both spike times and raw traces in a single file. +This is the only path that works end-to-end from the **MCP / agent interface**. + +Python +~~~~~~ + +.. code-block:: python + + from spikelab.data_loaders import load_spikedata_from_nwb + from spikelab.spikedata.hippie_adapter import classify_neurons + + sd = load_spikedata_from_nwb("/path/to/recording.nwb") + + # Extract waveforms for all units + sd.get_waveform_traces(unit=None, ms_before=1.0, ms_after=2.0, store=True) + + result = classify_neurons(sd, tech_id="neuropixels") + sd.set_neuron_attribute("hippie_cluster", result["cluster_labels"]) + sd.set_neuron_attribute("hippie_umap_x", result["umap_coords"][:, 0]) + sd.set_neuron_attribute("hippie_umap_y", result["umap_coords"][:, 1]) + sd.set_neuron_attribute("hippie_embedding", result["embeddings"]) + +MCP / agent +~~~~~~~~~~~ + +Give an agent these prompts in order: + +.. code-block:: text + + 1. "Load the NWB file at /path/to/recording.nwb" + + 2. "Extract waveforms for all N units with 1 ms before and 2 ms after the spike" + (the agent will call get_waveform_traces once per unit) + + 3. "Classify the neurons using HIPPIE with tech_id 0 (neuropixels)" + + 4. "How many clusters did HIPPIE find? List cluster IDs and neuron counts." + +.. note:: + + The MCP ``get_waveform_traces`` tool extracts one unit at a time. + For a recording with many units the agent needs to call it N times before + HIPPIE can run. See :ref:`hippie_mcp_gap` below. + + +.. _hippie_pipeline_c: + +Pipeline C — Waveforms already available +----------------------------------------- + +If ``avg_waveform`` is already in ``neuron_attributes`` — e.g. loaded from an +HDF5 workspace, set manually from an upstream pipeline, or computed in a +previous session — skip straight to classification: + +.. code-block:: python + + from spikelab.spikedata.hippie_adapter import classify_neurons + + # sd already has avg_waveform in neuron_attributes + result = classify_neurons(sd, tech_id="neuropixels") + + sd.set_neuron_attribute("hippie_cluster", result["cluster_labels"]) + sd.set_neuron_attribute("hippie_umap_x", result["umap_coords"][:, 0]) + sd.set_neuron_attribute("hippie_umap_y", result["umap_coords"][:, 1]) + sd.set_neuron_attribute("hippie_embedding", result["embeddings"]) + +To check whether waveforms are already present before trying: + +.. code-block:: python + + waves = sd.get_neuron_attribute("avg_waveform") + if waves is None or any(w is None for w in waves): + print("Waveforms missing — run get_waveform_traces first") + else: + print(f"Waveforms ready for {sd.N} units") + +Quick start (waveforms already present) +---------------------------------------- + +.. code-block:: python + + from spikelab.spikedata.hippie_adapter import classify_neurons + + result = classify_neurons( + sd, + tech_id="neuropixels", # or 0, 1, 2, 3 — see Technology IDs below + run_umap=True, + run_hdbscan=True, + ) + + sd.set_neuron_attribute("hippie_cluster", result["cluster_labels"]) + sd.set_neuron_attribute("hippie_umap_x", result["umap_coords"][:, 0]) + sd.set_neuron_attribute("hippie_umap_y", result["umap_coords"][:, 1]) + sd.set_neuron_attribute("hippie_embedding", result["embeddings"]) + +Return values +~~~~~~~~~~~~~ + +.. list-table:: + :header-rows: 1 + :widths: 20 15 65 + + * - Key + - Shape + - Description + * - ``embeddings`` + - ``(N, 30)`` + - Latent z_mean vectors from the HIPPIE encoder + * - ``umap_coords`` + - ``(N, 2)`` + - 2-D UMAP projection (present when ``run_umap=True``) + * - ``cluster_labels`` + - ``(N,)`` + - HDBSCAN cluster IDs; ``-1`` = noise / unclustered + (present when ``run_hdbscan=True``) + +Technology IDs +~~~~~~~~~~~~~~ + +The pretrained checkpoint was trained on recordings from four technology +families. Pass the matching ``tech_id`` for best results: + +.. list-table:: + :header-rows: 1 + :widths: 10 25 + + * - ``tech_id`` + - Technology + * - ``0`` / ``"neuropixels"`` + - Neuropixels probes *(default)* + * - ``1`` / ``"silicon_probe"`` + - Silicon probes (non-Neuropixels) + * - ``2`` / ``"juxtacellular"`` + - Juxtacellular recordings + * - ``3`` / ``"tetrodes"`` + - Tetrode recordings + +Advanced options +---------------- + +Tuning UMAP and HDBSCAN +~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + result = classify_neurons( + sd, + tech_id=0, + umap_kwargs={"n_neighbors": 15, "min_dist": 0.05}, + hdbscan_kwargs={"min_cluster_size": 10, "min_samples": 5}, + ) + +Embeddings only (no clustering) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Useful when you want to inspect the latent space before deciding on +clustering parameters: + +.. code-block:: python + + result = classify_neurons(sd, run_umap=False, run_hdbscan=False) + embeddings = result["embeddings"] # (N, 30) + +Using the HIPPIE API directly +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For full control over preprocessing or batching, import +:class:`hippie.inference.HIPPIEClassifier` directly: + +.. code-block:: python + + from hippie import HIPPIEClassifier + + clf = HIPPIEClassifier.from_pretrained("Jesusgf23/hippie", device="cpu") + + # Inputs must be preprocessed — see hippie_adapter.extract_features() + # for the exact normalization applied to each modality. + embeddings = clf.get_embeddings(wave, isi, acg, tech_id=0) + coords = clf.umap_reduce(embeddings, n_neighbors=30) + labels = clf.hdbscan_cluster(coords, min_cluster_size=5) + + # Load from a local checkpoint instead of HuggingFace + clf2 = HIPPIEClassifier.from_checkpoint("./my_trained_model.ckpt") + +Using via the MCP server +------------------------ + +The ``classify_neurons_hippie`` tool is available in the SpikeLab MCP +server once ``spikelab[hippie]`` is installed. After the tool runs, it +writes ``hippie_embedding``, ``hippie_umap_x``, ``hippie_umap_y``, and +``hippie_cluster`` directly into ``neuron_attributes``, making them +accessible to all downstream tools. + +Example agent prompts +~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: text + + "Classify the neurons in this recording using HIPPIE." + + "Run HIPPIE cell-type classification with tech_id 1 (silicon probe)." + + "Embed the neurons with HIPPIE and cluster with HDBSCAN, minimum cluster size 10." + + "What cell types did HIPPIE find? List the cluster IDs and neuron counts." + + "Plot the HIPPIE UMAP coloured by cluster label." + +.. _hippie_mcp_gap: + +Known limitation: MCP waveform extraction is per-unit +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The current ``get_waveform_traces`` MCP tool extracts waveforms for a +**single unit** per call. For a recording with N neurons, an agent must +call it N times before ``classify_neurons_hippie`` can run. + +Workaround until a bulk ``extract_all_waveforms`` tool is added: + +* Use Pipeline A or C in Python, where ``get_waveform_traces(unit=None)`` + processes all units in one call. +* Or pre-compute waveforms in Python and save the workspace; the agent can + then load it and run ``classify_neurons_hippie`` directly. + +How it works +------------ + +1. **Feature extraction** — For each neuron, SpikeLab computes: + + * *Waveform* (50 samples, min-max normalized to [-1, 1]) + * *ISI histogram* (100 log-spaced bins from 1 ms to 5 s, log(x+1) + transformed, then min-max normalized) + * *Autocorrelogram* (100 bins, 0–100 ms, min-max normalized) + +2. **Encoding** — Three modality-specific ResNet18 encoders project each + neuron's features into a shared 30-D latent space, conditioned on the + recording technology (``tech_id``). + +3. **UMAP** — The 30-D embeddings are projected to 2-D using cosine-distance + UMAP for visualization and clustering. + +4. **HDBSCAN** — Density clusters are found in the 2-D UMAP space. + Neurons that do not belong to any cluster receive label ``-1``. + +Checkpoint +---------- + +The pretrained model (``hippie_techcond_v1.ckpt``) is hosted at +`huggingface.co/Jesusgf23/hippie`_. It is downloaded automatically on +first use and cached locally (HuggingFace default cache, or override with +``cache_dir``). The file is 293 MB; subsequent calls use the local cache. + +.. _huggingface.co/Jesusgf23/hippie: https://huggingface.co/Jesusgf23/hippie + +The model was pretrained on 11 labeled electrophysiology datasets spanning +mouse, rat, and macaque across multiple brain regions and recording +technologies. diff --git a/docs/source/guides/index.rst b/docs/source/guides/index.rst index 9ea4cb90..c1f18635 100644 --- a/docs/source/guides/index.rst +++ b/docs/source/guides/index.rst @@ -20,3 +20,4 @@ In-depth guides covering the main workflows in SpikeLab. mcp_server parallel_computing batch_jobs + hippie_classification diff --git a/pyproject.toml b/pyproject.toml index c974e86e..998bff3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,6 +139,14 @@ ibl = [ # pip install git+https://github.com/int-brain-lab/paper-brain-wide-map.git "ONE-api>=2.0", ] +hippie = [ + # Install HIPPIE from GitHub (or a local editable install for development) + "hippie @ git+https://github.com/braingeneers/HIPPIE.git", + "huggingface-hub>=0.20", + "torch>=2.0", + "umap-learn>=0.5.0", + "hdbscan>=0.8", +] all = [ "pandas>=1.3", "mcp>=0.9.0", diff --git a/src/spikelab/mcp_server/server.py b/src/spikelab/mcp_server/server.py index 16e78f26..35b24b3f 100644 --- a/src/spikelab/mcp_server/server.py +++ b/src/spikelab/mcp_server/server.py @@ -3007,6 +3007,78 @@ async def _list_tools() -> list[types.Tool]: ] ) + # ----------------------------------------------------------------------- + # HIPPIE cell-type classification (optional — requires spikelab[hippie]) + # ----------------------------------------------------------------------- + tools.extend( + [ + types.Tool( + name="classify_neurons_hippie", + description=( + "Classify neurons using the pretrained HIPPIE multimodal model " + "(requires spikelab[hippie]). " + "Downloads the checkpoint from HuggingFace (Jesusgf23/hippie), " + "encodes each neuron's waveform + ISI + autocorrelogram into a " + "30-D latent space, optionally runs UMAP (2-D projection) and " + "HDBSCAN clustering, then stores the results as neuron_attributes " + "(hippie_embedding, hippie_umap_x, hippie_umap_y, hippie_cluster). " + "Requires avg_waveform in neuron_attributes — run " + "get_waveform_traces first if raw data is available." + ), + inputSchema={ + "type": "object", + "properties": { + **_WS_PROPS, + "tech_id": { + "type": "integer", + "description": ( + "Recording technology index: " + "0=neuropixels (default), 1=silicon_probe, " + "2=juxtacellular, 3=tetrodes" + ), + "default": 0, + }, + "run_umap": { + "type": "boolean", + "description": "Compute 2-D UMAP projection of embeddings (default: true)", + "default": True, + }, + "run_hdbscan": { + "type": "boolean", + "description": "Cluster with HDBSCAN on UMAP coords (default: true)", + "default": True, + }, + "min_cluster_size": { + "type": "integer", + "description": "Minimum neurons per HDBSCAN cluster (default: 5)", + "default": 5, + }, + "umap_n_neighbors": { + "type": "integer", + "description": "UMAP neighbourhood size (default: 30)", + "default": 30, + }, + "umap_min_dist": { + "type": "number", + "description": "UMAP minimum distance between points (default: 0.1)", + "default": 0.1, + }, + "device": { + "type": "string", + "description": "PyTorch device for the HIPPIE encoder: 'cpu' or 'cuda'", + "default": "cpu", + }, + "cache_dir": { + "type": "string", + "description": "Local directory to cache the downloaded checkpoint (optional)", + }, + }, + "required": ["workspace_id", "namespace"], + }, + ), + ] + ) + # ----------------------------------------------------------------------- # Workspace management tools # ----------------------------------------------------------------------- @@ -4217,6 +4289,8 @@ async def _list_tools() -> list[types.Tool]: "slice_trend": analysis.slice_trend, "slice_stability": analysis.slice_stability, "pairwise_tests": analysis.pairwise_tests, + # HIPPIE cell-type classification + "classify_neurons_hippie": analysis.classify_neurons_hippie, # Export tools "export_to_hdf5_raster": exporters.export_to_hdf5_raster, "export_to_hdf5_ragged": exporters.export_to_hdf5_ragged, diff --git a/src/spikelab/mcp_server/tools/analysis.py b/src/spikelab/mcp_server/tools/analysis.py index 9d0dbbac..f5e19b50 100644 --- a/src/spikelab/mcp_server/tools/analysis.py +++ b/src/spikelab/mcp_server/tools/analysis.py @@ -3061,3 +3061,103 @@ async def pairwise_tests( if out_key: response["key"] = out_key return response + + +# --------------------------------------------------------------------------- +# HIPPIE cell-type classification (optional — requires spikelab[hippie]) +# --------------------------------------------------------------------------- + + +async def classify_neurons_hippie( + workspace_id: str, + namespace: str, + tech_id: int = 0, + run_umap: bool = True, + run_hdbscan: bool = True, + min_cluster_size: int = 5, + umap_n_neighbors: int = 30, + umap_min_dist: float = 0.1, + device: str = "cpu", + cache_dir: Optional[str] = None, +) -> Dict[str, Any]: + """Classify neurons using the pretrained HIPPIE model (requires spikelab[hippie]). + + Downloads the HIPPIE checkpoint from HuggingFace, encodes all neurons into + a 30-dimensional latent space, and optionally runs UMAP projection and + HDBSCAN clustering. Results are stored back into the workspace as + neuron_attributes and as a workspace item. + + Requires avg_waveform to be present in neuron_attributes — run + get_waveform_traces first if raw data is available. + + Args: + workspace_id: Workspace ID. + namespace: Recording namespace. + tech_id: Recording technology index (0=neuropixels, 1=silicon_probe, + 2=juxtacellular, 3=tetrodes). + run_umap: Compute 2-D UMAP projection and store coordinates. + run_hdbscan: Cluster with HDBSCAN (-1 = noise). + min_cluster_size: Minimum neurons per HDBSCAN cluster. + umap_n_neighbors: UMAP neighbourhood size. + umap_min_dist: UMAP minimum distance between points. + device: "cuda" or "cpu" for the HIPPIE encoder. + cache_dir: Directory to cache the downloaded checkpoint. + """ + from ....spikedata.hippie_adapter import classify_neurons + + ws = _get_workspace(workspace_id) + sd = _get_spikedata(ws, namespace) + + umap_kwargs = {"n_neighbors": umap_n_neighbors, "min_dist": umap_min_dist} + hdbscan_kwargs = {"min_cluster_size": min_cluster_size} + + result = classify_neurons( + sd, + tech_id=tech_id, + device=device, + run_umap=run_umap, + run_hdbscan=run_hdbscan, + umap_kwargs=umap_kwargs, + hdbscan_kwargs=hdbscan_kwargs, + cache_dir=cache_dir, + ) + + # Store results as neuron_attributes + sd.set_neuron_attribute("hippie_embedding", result["embeddings"].tolist()) + if "umap_coords" in result: + sd.set_neuron_attribute("hippie_umap_x", result["umap_coords"][:, 0].tolist()) + sd.set_neuron_attribute("hippie_umap_y", result["umap_coords"][:, 1].tolist()) + if "cluster_labels" in result: + sd.set_neuron_attribute("hippie_cluster", result["cluster_labels"].tolist()) + + # Persist the updated SpikeData + ws.store(namespace, "spikedata", sd) + + n_clusters = ( + int(np.unique(result["cluster_labels"][result["cluster_labels"] >= 0]).size) + if "cluster_labels" in result + else None + ) + n_noise = ( + int((result["cluster_labels"] < 0).sum()) + if "cluster_labels" in result + else None + ) + + added_attrs = ["hippie_embedding"] + if "umap_coords" in result: + added_attrs += ["hippie_umap_x", "hippie_umap_y"] + if "cluster_labels" in result: + added_attrs.append("hippie_cluster") + + return { + "workspace_id": workspace_id, + "namespace": namespace, + "n_neurons": int(result["embeddings"].shape[0]), + "embedding_dim": int(result["embeddings"].shape[1]), + "umap_computed": "umap_coords" in result, + "hdbscan_computed": "cluster_labels" in result, + "n_clusters": n_clusters, + "n_noise_neurons": n_noise, + "neuron_attributes_added": added_attrs, + } diff --git a/src/spikelab/spikedata/hippie_adapter.py b/src/spikelab/spikedata/hippie_adapter.py new file mode 100644 index 00000000..39a902e8 --- /dev/null +++ b/src/spikelab/spikedata/hippie_adapter.py @@ -0,0 +1,221 @@ +"""HIPPIE neuron classification adapter for SpikeData. + +Requires: pip install spikelab[hippie] + +Workflow: + 1. extract_features(sd) — compute wave/ISI/ACG arrays from a SpikeData object + 2. classify_neurons(sd) — one-call pipeline: embed → UMAP → HDBSCAN +""" + +from __future__ import annotations + +from typing import Optional, Union + +import numpy as np + + +# ISI histogram parameters (log-spaced, matching HIPPIE training data convention) +_ISI_N_BINS = 100 +_ISI_MIN_MS = 1.0 +_ISI_MAX_MS = 5000.0 + +# ACG parameters +_ACG_N_BINS = 100 +_ACG_MAX_LAG_MS = 100.0 + + +def _require_hippie(): + try: + import hippie # noqa: F401 + except ImportError: + raise ImportError( + "HIPPIE is required for neuron classification. " + "Install it with: pip install spikelab[hippie]" + ) + + +# ------------------------------------------------------------------ +# Per-neuron preprocessing helpers +# ------------------------------------------------------------------ + +def _preprocess_waveform(wave: np.ndarray, target: int = 50) -> np.ndarray: + """Resample waveform to target length and min-max normalize to [-1, 1].""" + import torch + import torch.nn.functional as F + + t = torch.as_tensor(wave, dtype=torch.float32).view(1, 1, -1) + t = F.interpolate(t, size=(target,), mode="linear", align_corners=False).squeeze() + mn, mx = t.min().item(), t.max().item() + if mx > mn: + t = (t - mn) / (mx - mn) * 2.0 - 1.0 + return t.numpy().astype(np.float32) + + +def _isi_histogram(spike_times: np.ndarray, n_bins: int = _ISI_N_BINS) -> np.ndarray: + """Compute a log-spaced ISI histogram, log(x+1)-transformed and min-max normalized.""" + isis_ms = np.diff(np.sort(spike_times)) * 1000.0 + isis_ms = isis_ms[isis_ms > 0] + if len(isis_ms) < 2: + return np.full(n_bins, -1.0, dtype=np.float32) # return flat -1 for silent neurons + + bins = np.logspace(np.log10(_ISI_MIN_MS), np.log10(_ISI_MAX_MS), n_bins + 1) + hist, _ = np.histogram(isis_ms, bins=bins, density=True) + hist = hist.astype(np.float32) + + # log(x+1) transform then min-max to [-1, 1] — matches MultiModalEphysDataset + hist = np.log1p(hist) + mn, mx = hist.min(), hist.max() + if mx > mn: + hist = (hist - mn) / (mx - mn + 1e-8) * 2.0 - 1.0 + return hist + + +def _autocorrelogram( + spike_times: np.ndarray, + max_lag_ms: float = _ACG_MAX_LAG_MS, + n_bins: int = _ACG_N_BINS, +) -> np.ndarray: + """Compute a half-sided autocorrelogram (forward lags only), min-max normalized.""" + if len(spike_times) < 2: + return np.zeros(n_bins, dtype=np.float32) + + st_ms = np.sort(spike_times) * 1000.0 + bin_edges = np.linspace(0.0, max_lag_ms, n_bins + 1) + counts = np.zeros(n_bins, dtype=np.float64) + + for i in range(len(st_ms)): + hi = np.searchsorted(st_ms, st_ms[i] + max_lag_ms, side="right") + lo = i + 1 + if lo < hi: + diffs = st_ms[lo:hi] - st_ms[i] + counts += np.histogram(diffs, bins=bin_edges)[0] + + total = counts.sum() + if total > 0: + counts /= total + + acg = counts.astype(np.float32) + mn, mx = acg.min(), acg.max() + if mx > mn: + acg = (acg - mn) / (mx - mn + 1e-8) * 2.0 - 1.0 + return acg + + +# ------------------------------------------------------------------ +# Public API +# ------------------------------------------------------------------ + +def extract_features( + sd, + isi_bins: int = _ISI_N_BINS, + acg_bins: int = _ACG_N_BINS, + acg_max_lag_ms: float = _ACG_MAX_LAG_MS, +) -> dict: + """Extract waveform, ISI, and ACG features from a SpikeData object. + + Waveforms are read from ``neuron_attributes["avg_waveform"]``. Call + ``sd.get_waveform_traces()`` first if raw_data is available and + avg_waveform has not yet been computed. + + Args: + sd: SpikeData instance with spike trains and avg_waveform attributes. + isi_bins: Number of log-spaced ISI histogram bins. + acg_bins: Number of autocorrelogram bins. + acg_max_lag_ms: Maximum lag for the autocorrelogram (milliseconds). + + Returns: + dict with keys: + - "wave": (N, 50) min-max normalized waveforms + - "isi": (N, 100) log-transformed, normalized ISI histograms + - "acg": (N, 100) normalized autocorrelograms + """ + waves = sd.get_neuron_attribute("avg_waveform") + if waves is None or any(w is None for w in waves): + raise ValueError( + "avg_waveform not found in neuron_attributes. " + "Call sd.get_waveform_traces() first, or set avg_waveform manually." + ) + + wave_arr = np.stack([_preprocess_waveform(np.asarray(w)) for w in waves]) + isi_arr = np.stack([_isi_histogram(t, n_bins=isi_bins) for t in sd.train]) + acg_arr = np.stack([ + _autocorrelogram(t, max_lag_ms=acg_max_lag_ms, n_bins=acg_bins) + for t in sd.train + ]) + + return {"wave": wave_arr, "isi": isi_arr, "acg": acg_arr} + + +def classify_neurons( + sd, + repo_id: str = "Jesusgf23/hippie", + tech_id: Union[int, str] = 0, + device: str = "cpu", + run_umap: bool = True, + run_hdbscan: bool = True, + umap_kwargs: Optional[dict] = None, + hdbscan_kwargs: Optional[dict] = None, + batch_size: int = 256, + cache_dir: Optional[str] = None, +) -> dict: + """Classify neurons in a SpikeData object using HIPPIE. + + Downloads the pretrained HIPPIE checkpoint, encodes all neurons into + the latent space, and optionally runs UMAP dimensionality reduction + followed by HDBSCAN clustering. + + Args: + sd: SpikeData with spike trains and avg_waveform in neuron_attributes. + repo_id: HuggingFace repository ID for the HIPPIE checkpoint. + tech_id: Recording technology — int index or one of: + "neuropixels" (0), "silicon_probe" (1), + "juxtacellular" (2), "tetrodes" (3). + device: "cuda" or "cpu". + run_umap: Compute 2-D UMAP projection of the embeddings. + run_hdbscan: Cluster with HDBSCAN (applied on UMAP coords when + run_umap=True, otherwise on raw embeddings). + umap_kwargs: Extra keyword arguments for HIPPIEClassifier.umap_reduce(). + hdbscan_kwargs: Extra keyword arguments for HIPPIEClassifier.hdbscan_cluster(). + batch_size: Neurons per forward pass. + cache_dir: Local directory to cache the downloaded checkpoint. + + Returns: + dict with keys: + - "embeddings": (N, 30) latent z_mean vectors + - "umap_coords": (N, 2) UMAP coordinates (present if run_umap=True) + - "cluster_labels":(N,) HDBSCAN labels, -1=noise (present if run_hdbscan=True) + + Example: + >>> from spikelab.spikedata.hippie_adapter import classify_neurons + >>> result = classify_neurons(sd, tech_id="neuropixels") + >>> sd.set_neuron_attribute("hippie_cluster", result["cluster_labels"]) + >>> sd.set_neuron_attribute("hippie_embedding", result["embeddings"]) + """ + _require_hippie() + from hippie.inference import HIPPIEClassifier + + features = extract_features(sd) + + clf = HIPPIEClassifier.from_pretrained( + repo_id=repo_id, device=device, cache_dir=cache_dir + ) + embeddings = clf.get_embeddings( + features["wave"], + features["isi"], + features["acg"], + tech_id=tech_id, + batch_size=batch_size, + ) + + result: dict = {"embeddings": embeddings} + + if run_umap: + result["umap_coords"] = clf.umap_reduce(embeddings, **(umap_kwargs or {})) + + if run_hdbscan: + cluster_input = result.get("umap_coords", embeddings) + result["cluster_labels"] = clf.hdbscan_cluster( + cluster_input, **(hdbscan_kwargs or {}) + ) + + return result diff --git a/tests/test_hippie_adapter.py b/tests/test_hippie_adapter.py new file mode 100644 index 00000000..3276601a --- /dev/null +++ b/tests/test_hippie_adapter.py @@ -0,0 +1,185 @@ +"""Tests for the HIPPIE cell-type classification adapter. + +All tests mock the HuggingFace download and model forward pass so the +293 MB checkpoint is never fetched during CI. +""" + +import numpy as np +import pytest +from unittest.mock import MagicMock, patch + +# Skip every test in this file if hippie is not installed +hippie = pytest.importorskip("hippie", reason="spikelab[hippie] not installed") + +from spikelab.spikedata.hippie_adapter import ( + _isi_histogram, + _autocorrelogram, + _preprocess_waveform, + extract_features, + classify_neurons, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +def _make_spike_train(n_spikes=200, duration_s=60.0, seed=0): + rng = np.random.default_rng(seed) + return np.sort(rng.uniform(0, duration_s, n_spikes)) + + +def _make_waveform(n=82, seed=0): + rng = np.random.default_rng(seed) + t = np.linspace(0, 2 * np.pi, n) + return np.sin(t) + rng.normal(0, 0.05, n) + + +def _make_spikedata(n_units=10, seed=0): + """Return a minimal SpikeData with avg_waveform in neuron_attributes.""" + from spikelab.spikedata import SpikeData + + rng = np.random.default_rng(seed) + trains = [_make_spike_train(200 + i * 10, seed=seed + i) for i in range(n_units)] + waveforms = [_make_waveform(seed=seed + i) for i in range(n_units)] + attrs = [{"avg_waveform": w} for w in waveforms] + return SpikeData(trains, length=60.0, neuron_attributes=attrs) + + +# --------------------------------------------------------------------------- +# Unit tests for preprocessing helpers +# --------------------------------------------------------------------------- + +class TestPreprocessWaveform: + def test_output_shape(self): + wave = _make_waveform(82) + out = _preprocess_waveform(wave, target=50) + assert out.shape == (50,) + + def test_range(self): + wave = _make_waveform(82) + out = _preprocess_waveform(wave) + assert out.min() >= -1.0 - 1e-5 + assert out.max() <= 1.0 + 1e-5 + + def test_flat_waveform_does_not_crash(self): + wave = np.zeros(50) + out = _preprocess_waveform(wave) + assert out.shape == (50,) + assert np.isfinite(out).all() + + +class TestISIHistogram: + def test_output_shape(self): + st = _make_spike_train(200) + hist = _isi_histogram(st, n_bins=100) + assert hist.shape == (100,) + + def test_range(self): + st = _make_spike_train(200) + hist = _isi_histogram(st) + assert hist.min() >= -1.0 - 1e-5 + assert hist.max() <= 1.0 + 1e-5 + + def test_silent_neuron(self): + hist = _isi_histogram(np.array([0.5]), n_bins=100) + assert hist.shape == (100,) + assert np.isfinite(hist).all() + + +class TestAutocorrelogram: + def test_output_shape(self): + st = _make_spike_train(200) + acg = _autocorrelogram(st, n_bins=100) + assert acg.shape == (100,) + + def test_range(self): + st = _make_spike_train(200) + acg = _autocorrelogram(st) + assert acg.min() >= -1.0 - 1e-5 + assert acg.max() <= 1.0 + 1e-5 + + def test_empty_train(self): + acg = _autocorrelogram(np.array([]), n_bins=100) + assert acg.shape == (100,) + assert (acg == 0).all() + + +# --------------------------------------------------------------------------- +# extract_features +# --------------------------------------------------------------------------- + +class TestExtractFeatures: + def test_shapes(self): + sd = _make_spikedata(n_units=8) + feats = extract_features(sd) + assert feats["wave"].shape == (8, 50) + assert feats["isi"].shape == (8, 100) + assert feats["acg"].shape == (8, 100) + + def test_dtype(self): + sd = _make_spikedata(n_units=5) + feats = extract_features(sd) + for arr in feats.values(): + assert arr.dtype == np.float32 + + def test_no_waveform_raises(self): + from spikelab.spikedata import SpikeData + trains = [_make_spike_train(100, seed=i) for i in range(3)] + sd = SpikeData(trains, length=60.0) + with pytest.raises(ValueError, match="avg_waveform"): + extract_features(sd) + + +# --------------------------------------------------------------------------- +# classify_neurons — mocked end-to-end +# --------------------------------------------------------------------------- + +class TestClassifyNeurons: + """Full pipeline test with the HuggingFace download and HIPPIE model mocked out.""" + + def _make_mock_classifier(self, n_neurons, z_dim=30): + mock_clf = MagicMock() + mock_clf.get_embeddings.return_value = np.random.randn(n_neurons, z_dim).astype(np.float32) + mock_clf.umap_reduce.return_value = np.random.randn(n_neurons, 2).astype(np.float32) + mock_clf.hdbscan_cluster.return_value = np.zeros(n_neurons, dtype=np.int32) + return mock_clf + + @patch("hippie.inference.HIPPIEClassifier") + def test_returns_all_keys_by_default(self, MockCls): + n = 10 + MockCls.from_pretrained.return_value = self._make_mock_classifier(n) + sd = _make_spikedata(n_units=n) + result = classify_neurons(sd) + assert "embeddings" in result + assert "umap_coords" in result + assert "cluster_labels" in result + + @patch("hippie.inference.HIPPIEClassifier") + def test_no_umap_no_hdbscan(self, MockCls): + n = 6 + MockCls.from_pretrained.return_value = self._make_mock_classifier(n) + sd = _make_spikedata(n_units=n) + result = classify_neurons(sd, run_umap=False, run_hdbscan=False) + assert "embeddings" in result + assert "umap_coords" not in result + assert "cluster_labels" not in result + + @patch("hippie.inference.HIPPIEClassifier") + def test_embedding_shape(self, MockCls): + n = 12 + MockCls.from_pretrained.return_value = self._make_mock_classifier(n, z_dim=30) + sd = _make_spikedata(n_units=n) + result = classify_neurons(sd) + assert result["embeddings"].shape == (n, 30) + + @patch("hippie.inference.HIPPIEClassifier") + def test_tech_id_string(self, MockCls): + n = 5 + mock_clf = self._make_mock_classifier(n) + MockCls.from_pretrained.return_value = mock_clf + sd = _make_spikedata(n_units=n) + classify_neurons(sd, tech_id="silicon_probe") + mock_clf.get_embeddings.assert_called_once() + call_kwargs = mock_clf.get_embeddings.call_args + assert call_kwargs.kwargs.get("tech_id") == "silicon_probe" From faba91380ee43d374cfaf6b5ad01b230fe08568a Mon Sep 17 00:00:00 2001 From: JesusGF1 Date: Sun, 26 Apr 2026 19:32:16 -0700 Subject: [PATCH 2/8] Add unconditioned VAE training and compression pipeline Extends the HIPPIE integration with a self-supervised data compression path for users without cell-type or technology labels. Changes: - hippie_adapter.py: train_vae_on_spikedata() and compress_neurons() - analysis.py: train_vae_hippie() and compress_neurons_hippie() MCP tools - server.py: tool schemas and dispatch entries for both new tools - hippie_classification.rst: new "Unsupervised VAE compression" section with Python, direct-API, and MCP agent examples and a comparison table against the pretrained classifier Co-Authored-By: Claude Sonnet 4.6 --- docs/source/guides/hippie_classification.rst | 114 +++++++++++++++ src/spikelab/mcp_server/server.py | 120 +++++++++++++++- src/spikelab/mcp_server/tools/analysis.py | 137 +++++++++++++++++++ src/spikelab/spikedata/hippie_adapter.py | 123 +++++++++++++++++ 4 files changed, 493 insertions(+), 1 deletion(-) diff --git a/docs/source/guides/hippie_classification.rst b/docs/source/guides/hippie_classification.rst index e9139180..a7bbe0f7 100644 --- a/docs/source/guides/hippie_classification.rst +++ b/docs/source/guides/hippie_classification.rst @@ -259,6 +259,120 @@ families. Pass the matching ``tech_id`` for best results: * - ``3`` / ``"tetrodes"`` - Tetrode recordings +Unsupervised VAE compression (no conditioning) +---------------------------------------------- + +If you do not have cell-type or technology labels, or simply want to learn +a compressed representation of *your own dataset* from scratch, the +unconditioned VAE pipeline trains the same ResNet18 + fusion encoder +architecture as the pretrained model but with all conditioning removed. +The only training signal is reconstruction + KL divergence (beta-VAE ELBO, +beta=1) — no class embeddings, no technology embeddings. + +Results are stored as ``vae_embedding``, ``vae_umap_x``, ``vae_umap_y``, +and ``vae_cluster`` in ``neuron_attributes``, keeping them separate from the +pretrained-model attributes (``hippie_embedding`` etc.). + +Train and compress in Python +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + from spikelab.spikedata.hippie_adapter import ( + train_vae_on_spikedata, + compress_neurons, + ) + + # Step 1 — train on your data (requires avg_waveform in neuron_attributes) + compressor = train_vae_on_spikedata( + sd, + output_dir="./my_vae", + z_dim=30, # latent dimensionality + n_epochs=100, + batch_size=256, + device="cpu", # or "cuda" + ) + # Checkpoint saved to ./my_vae/vae_best.ckpt + + # Step 2 — compress (can reuse the returned compressor, or reload later) + result = compress_neurons(sd, compressor, run_umap=True, run_hdbscan=True) + + sd.set_neuron_attribute("vae_cluster", result["cluster_labels"]) + sd.set_neuron_attribute("vae_umap_x", result["umap_coords"][:, 0]) + sd.set_neuron_attribute("vae_umap_y", result["umap_coords"][:, 1]) + sd.set_neuron_attribute("vae_embedding", result["embeddings"]) + +Reload and compress new data later +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + from spikelab.spikedata.hippie_adapter import compress_neurons + + # Load a previously trained checkpoint by path + result = compress_neurons(sd_new, "./my_vae/vae_best.ckpt") + +Use the VAE API directly +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + from hippie.vae import train_vae, VAECompressor + + # Train + compressor = train_vae(wave, isi, acg, output_dir="./my_vae", z_dim=30, n_epochs=100) + + # Or load an existing checkpoint + compressor = VAECompressor.from_checkpoint("./my_vae/vae_best.ckpt") + + # Encode + embeddings = compressor.get_embeddings(wave, isi, acg) # (N, z_dim) + coords = compressor.umap_reduce(embeddings) + labels = compressor.hdbscan_cluster(coords, min_cluster_size=5) + +MCP / agent +~~~~~~~~~~~ + +.. code-block:: text + + 1. "Train an unconditioned VAE on the neurons in namespace 'probe0', + saving to ./my_vae, with 50 epochs." + → calls train_vae_hippie + + 2. "Compress the neurons in 'probe0' using the checkpoint at ./my_vae/vae_best.ckpt." + → calls compress_neurons_hippie + + 3. "How many clusters did the VAE find?" + +How it differs from the pretrained classifier +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. list-table:: + :header-rows: 1 + :widths: 30 35 35 + + * - + - Pretrained HIPPIE (``classify_neurons``) + - Unconditioned VAE (``compress_neurons``) + * - Requires labels + - No (inference only) + - No (train & infer) + * - Requires ``tech_id`` + - Yes + - No + * - Trains on your data + - No + - Yes + * - Learns from 11 datasets + - Yes (pretrained) + - No (your data only) + * - Latent space shaped by + - Cell types + technology + - Reconstruction only + * - Best for + - Known-technology recordings + - Exploratory compression, novel datasets + Advanced options ---------------- diff --git a/src/spikelab/mcp_server/server.py b/src/spikelab/mcp_server/server.py index 35b24b3f..a8750f3d 100644 --- a/src/spikelab/mcp_server/server.py +++ b/src/spikelab/mcp_server/server.py @@ -3079,6 +3079,121 @@ async def _list_tools() -> list[types.Tool]: ] ) + # ----------------------------------------------------------------------- + # Unconditioned VAE: training + compression (requires spikelab[hippie]) + # ----------------------------------------------------------------------- + tools.extend( + [ + types.Tool( + name="train_vae_hippie", + description=( + "Train an unconditioned multimodal VAE on a SpikeData object " + "(requires spikelab[hippie]). " + "Uses the same ResNet18 + fusion encoder as the pretrained HIPPIE " + "model but removes all class and technology conditioning — the VAE " + "learns to compress waveform + ISI + autocorrelogram into a latent " + "space using only reconstruction + KL loss. " + "Saves the best checkpoint to output_dir/vae_best.ckpt. " + "Pass that path to compress_neurons_hippie to encode new data. " + "Requires avg_waveform in neuron_attributes — run get_waveform_traces first." + ), + inputSchema={ + "type": "object", + "properties": { + **_WS_PROPS, + "output_dir": { + "type": "string", + "description": "Directory to save the best checkpoint (vae_best.ckpt)", + }, + "z_dim": { + "type": "integer", + "description": "Latent space dimensionality (default 30)", + "default": 30, + }, + "n_epochs": { + "type": "integer", + "description": "Number of training epochs (default 100)", + "default": 100, + }, + "batch_size": { + "type": "integer", + "description": "Minibatch size (default 256)", + "default": 256, + }, + "learning_rate": { + "type": "number", + "description": "AdamW learning rate (default 1e-3)", + "default": 0.001, + }, + "val_fraction": { + "type": "number", + "description": "Fraction of neurons held out for validation (default 0.1)", + "default": 0.1, + }, + "device": { + "type": "string", + "description": "PyTorch device: 'cpu' or 'cuda'", + "default": "cpu", + }, + }, + "required": ["workspace_id", "namespace", "output_dir"], + }, + ), + types.Tool( + name="compress_neurons_hippie", + description=( + "Compress neurons using a trained unconditioned VAE " + "(requires spikelab[hippie]). " + "Encodes all neurons into the VAE latent space, optionally runs " + "UMAP (2-D projection) and HDBSCAN clustering, then stores results " + "as neuron_attributes: vae_embedding, vae_umap_x, vae_umap_y, " + "vae_cluster. Train the model first with train_vae_hippie." + ), + inputSchema={ + "type": "object", + "properties": { + **_WS_PROPS, + "checkpoint_path": { + "type": "string", + "description": "Path to the .ckpt file saved by train_vae_hippie", + }, + "run_umap": { + "type": "boolean", + "description": "Compute 2-D UMAP projection (default true)", + "default": True, + }, + "run_hdbscan": { + "type": "boolean", + "description": "Cluster with HDBSCAN (default true)", + "default": True, + }, + "min_cluster_size": { + "type": "integer", + "description": "Minimum neurons per HDBSCAN cluster (default 5)", + "default": 5, + }, + "umap_n_neighbors": { + "type": "integer", + "description": "UMAP neighbourhood size (default 30)", + "default": 30, + }, + "umap_min_dist": { + "type": "number", + "description": "UMAP minimum distance (default 0.1)", + "default": 0.1, + }, + "device": { + "type": "string", + "description": "PyTorch device: 'cpu' or 'cuda'", + "default": "cpu", + }, + }, + "required": ["workspace_id", "namespace", "checkpoint_path"], + }, + ), + ] + ) + # ----------------------------------------------------------------------- # Workspace management tools # ----------------------------------------------------------------------- @@ -4264,6 +4379,7 @@ async def _list_tools() -> list[types.Tool]: "load_workspace_item": analysis.load_workspace_item, "merge_workspace": analysis.merge_workspace, "fetch_workspace_item": analysis.fetch_workspace_item, +<<<<<<< HEAD # Shuffling and stack builders "spike_shuffle": analysis.spike_shuffle, "spike_shuffle_stack": analysis.spike_shuffle_stack, @@ -4289,8 +4405,10 @@ async def _list_tools() -> list[types.Tool]: "slice_trend": analysis.slice_trend, "slice_stability": analysis.slice_stability, "pairwise_tests": analysis.pairwise_tests, - # HIPPIE cell-type classification + # HIPPIE cell-type classification and VAE compression "classify_neurons_hippie": analysis.classify_neurons_hippie, + "train_vae_hippie": analysis.train_vae_hippie, + "compress_neurons_hippie": analysis.compress_neurons_hippie, # Export tools "export_to_hdf5_raster": exporters.export_to_hdf5_raster, "export_to_hdf5_ragged": exporters.export_to_hdf5_ragged, diff --git a/src/spikelab/mcp_server/tools/analysis.py b/src/spikelab/mcp_server/tools/analysis.py index f5e19b50..65868201 100644 --- a/src/spikelab/mcp_server/tools/analysis.py +++ b/src/spikelab/mcp_server/tools/analysis.py @@ -3161,3 +3161,140 @@ async def classify_neurons_hippie( "n_noise_neurons": n_noise, "neuron_attributes_added": added_attrs, } + + +# --------------------------------------------------------------------------- +# Unconditioned VAE: training + compression (requires spikelab[hippie]) +# --------------------------------------------------------------------------- + + +async def train_vae_hippie( + workspace_id: str, + namespace: str, + output_dir: str, + z_dim: int = 30, + n_epochs: int = 100, + batch_size: int = 256, + learning_rate: float = 1e-3, + val_fraction: float = 0.1, + device: str = "cpu", +) -> Dict[str, Any]: + """Train an unconditioned multimodal VAE on a SpikeData object (requires spikelab[hippie]). + + Uses the same ResNet18 + fusion encoder architecture as the pretrained HIPPIE + model but removes all class and technology conditioning. The VAE learns to + compress waveform + ISI + autocorrelogram into a z_dim-dimensional latent + space using only reconstruction + KL loss (beta=1). + + The best checkpoint is saved to output_dir/vae_best.ckpt. Pass this path + to compress_neurons_hippie to encode new data. + + Requires avg_waveform in neuron_attributes — run get_waveform_traces first. + """ + from ....spikedata.hippie_adapter import train_vae_on_spikedata + + ws = _get_workspace(workspace_id) + sd = _get_spikedata(ws, namespace) + + train_vae_on_spikedata( + sd, + output_dir=output_dir, + z_dim=z_dim, + n_epochs=n_epochs, + batch_size=batch_size, + learning_rate=learning_rate, + val_fraction=val_fraction, + device=device, + ) + + import os + ckpt_path = os.path.join(output_dir, "vae_best.ckpt") + return { + "workspace_id": workspace_id, + "namespace": namespace, + "checkpoint_path": ckpt_path, + "z_dim": z_dim, + "n_epochs": n_epochs, + "n_neurons_trained_on": sd.N, + } + + +async def compress_neurons_hippie( + workspace_id: str, + namespace: str, + checkpoint_path: str, + run_umap: bool = True, + run_hdbscan: bool = True, + min_cluster_size: int = 5, + umap_n_neighbors: int = 30, + umap_min_dist: float = 0.1, + device: str = "cpu", +) -> Dict[str, Any]: + """Compress neurons with a trained unconditioned VAE (requires spikelab[hippie]). + + Encodes all neurons into the VAE latent space, optionally runs UMAP and + HDBSCAN, then writes results into neuron_attributes: + vae_embedding, vae_umap_x, vae_umap_y, vae_cluster. + + Args: + workspace_id: Workspace ID. + namespace: Recording namespace. + checkpoint_path: Path to the .ckpt file saved by train_vae_hippie. + run_umap: Compute 2-D UMAP projection. + run_hdbscan: Cluster with HDBSCAN (-1 = noise). + min_cluster_size: Minimum neurons per cluster. + umap_n_neighbors: UMAP neighbourhood size. + umap_min_dist: UMAP minimum distance. + device: "cuda" or "cpu". + """ + from ....spikedata.hippie_adapter import compress_neurons + + ws = _get_workspace(workspace_id) + sd = _get_spikedata(ws, namespace) + + result = compress_neurons( + sd, + compressor=checkpoint_path, + run_umap=run_umap, + run_hdbscan=run_hdbscan, + umap_kwargs={"n_neighbors": umap_n_neighbors, "min_dist": umap_min_dist}, + hdbscan_kwargs={"min_cluster_size": min_cluster_size}, + device=device, + ) + + sd.set_neuron_attribute("vae_embedding", result["embeddings"].tolist()) + if "umap_coords" in result: + sd.set_neuron_attribute("vae_umap_x", result["umap_coords"][:, 0].tolist()) + sd.set_neuron_attribute("vae_umap_y", result["umap_coords"][:, 1].tolist()) + if "cluster_labels" in result: + sd.set_neuron_attribute("vae_cluster", result["cluster_labels"].tolist()) + + ws.store(namespace, "spikedata", sd) + + n_clusters = ( + int(np.unique(result["cluster_labels"][result["cluster_labels"] >= 0]).size) + if "cluster_labels" in result + else None + ) + + added_attrs = ["vae_embedding"] + if "umap_coords" in result: + added_attrs += ["vae_umap_x", "vae_umap_y"] + if "cluster_labels" in result: + added_attrs.append("vae_cluster") + + return { + "workspace_id": workspace_id, + "namespace": namespace, + "n_neurons": int(result["embeddings"].shape[0]), + "embedding_dim": int(result["embeddings"].shape[1]), + "umap_computed": "umap_coords" in result, + "hdbscan_computed": "cluster_labels" in result, + "n_clusters": n_clusters, + "n_noise_neurons": ( + int((result["cluster_labels"] < 0).sum()) + if "cluster_labels" in result + else None + ), + "neuron_attributes_added": added_attrs, + } diff --git a/src/spikelab/spikedata/hippie_adapter.py b/src/spikelab/spikedata/hippie_adapter.py index 39a902e8..42f8b0e0 100644 --- a/src/spikelab/spikedata/hippie_adapter.py +++ b/src/spikelab/spikedata/hippie_adapter.py @@ -9,6 +9,7 @@ from __future__ import annotations +from pathlib import Path from typing import Optional, Union import numpy as np @@ -219,3 +220,125 @@ def classify_neurons( ) return result + + +# --------------------------------------------------------------------------- +# Unconditioned VAE: training + compression +# --------------------------------------------------------------------------- + +def train_vae_on_spikedata( + sd, + output_dir: str, + z_dim: int = 30, + n_epochs: int = 100, + batch_size: int = 256, + learning_rate: float = 1e-3, + weight_decay: float = 1e-2, + val_fraction: float = 0.1, + device: str = "cpu", + random_state: int = 42, +) -> "VAECompressor": + """Train an unconditioned multimodal VAE on a SpikeData object. + + Extracts waveform, ISI, and ACG features from sd, then trains a VAE + (no class or technology conditioning) to compress the data. The best + checkpoint is saved to output_dir and a ready VAECompressor is returned. + + Args: + sd: SpikeData with spike trains and avg_waveform in neuron_attributes. + output_dir: Directory to save the best checkpoint (vae_best.ckpt). + z_dim: Latent space dimensionality (default 30). + n_epochs: Training epochs. + batch_size: Minibatch size. + learning_rate: AdamW learning rate. + weight_decay: AdamW weight decay. + val_fraction: Fraction of neurons held out for validation. + device: "cuda" or "cpu". + random_state: Reproducibility seed. + + Returns: + VAECompressor loaded from the best checkpoint. + + Example: + >>> from spikelab.spikedata.hippie_adapter import train_vae_on_spikedata + >>> compressor = train_vae_on_spikedata(sd, output_dir="./my_vae", n_epochs=50) + >>> result = compress_neurons(sd, compressor) + """ + _require_hippie() + from hippie.vae import train_vae + + features = extract_features(sd) + return train_vae( + wave=features["wave"], + isi=features["isi"], + acg=features["acg"], + output_dir=output_dir, + z_dim=z_dim, + n_epochs=n_epochs, + batch_size=batch_size, + learning_rate=learning_rate, + weight_decay=weight_decay, + val_fraction=val_fraction, + device=device, + random_state=random_state, + ) + + +def compress_neurons( + sd, + compressor: Union[str, "VAECompressor"], + run_umap: bool = True, + run_hdbscan: bool = True, + umap_kwargs: Optional[dict] = None, + hdbscan_kwargs: Optional[dict] = None, + batch_size: int = 256, + device: str = "cpu", +) -> dict: + """Compress neurons with a trained unconditioned VAE. + + Args: + sd: SpikeData with spike trains and avg_waveform in neuron_attributes. + compressor: A VAECompressor instance or a path to a checkpoint (.ckpt). + run_umap: Compute 2-D UMAP projection of the embeddings. + run_hdbscan: Cluster with HDBSCAN on UMAP coords (or raw embeddings). + umap_kwargs: Extra kwargs forwarded to VAECompressor.umap_reduce(). + hdbscan_kwargs: Extra kwargs forwarded to VAECompressor.hdbscan_cluster(). + batch_size: Neurons per forward pass. + device: "cuda" or "cpu" (only used when loading from a checkpoint path). + + Returns: + dict with keys: + - "embeddings": (N, z_dim) latent z_mean vectors + - "umap_coords": (N, 2) UMAP coordinates (if run_umap=True) + - "cluster_labels":(N,) HDBSCAN labels, -1=noise (if run_hdbscan=True) + + Example: + >>> result = compress_neurons(sd, "./my_vae/vae_best.ckpt") + >>> sd.set_neuron_attribute("vae_cluster", result["cluster_labels"]) + >>> sd.set_neuron_attribute("vae_umap_x", result["umap_coords"][:, 0]) + >>> sd.set_neuron_attribute("vae_umap_y", result["umap_coords"][:, 1]) + >>> sd.set_neuron_attribute("vae_embedding", result["embeddings"]) + """ + _require_hippie() + from hippie.vae import VAECompressor + + if isinstance(compressor, (str, Path)): + compressor = VAECompressor.from_checkpoint(compressor, device=device) + + features = extract_features(sd) + embeddings = compressor.get_embeddings( + features["wave"], features["isi"], features["acg"], batch_size=batch_size + ) + + result: dict = {"embeddings": embeddings} + + if run_umap: + result["umap_coords"] = compressor.umap_reduce(embeddings, **(umap_kwargs or {})) + + if run_hdbscan: + cluster_input = result.get("umap_coords", embeddings) + result["cluster_labels"] = compressor.hdbscan_cluster( + cluster_input, **(hdbscan_kwargs or {}) + ) + + return result From 3be618c60e7daaa504cdc3041e0f26125de1dfbf Mon Sep 17 00:00:00 2001 From: JesusGF1 Date: Mon, 27 Apr 2026 09:48:50 -0700 Subject: [PATCH 3/8] Apply Black formatting to HIPPIE adapter, analysis tools, and tests Co-Authored-By: Claude Sonnet 4.6 --- src/spikelab/mcp_server/tools/analysis.py | 39 ++++++++++++++--------- src/spikelab/spikedata/hippie_adapter.py | 21 ++++++++---- tests/test_hippie_adapter.py | 13 ++++++-- 3 files changed, 50 insertions(+), 23 deletions(-) diff --git a/src/spikelab/mcp_server/tools/analysis.py b/src/spikelab/mcp_server/tools/analysis.py index 65868201..5375ea64 100644 --- a/src/spikelab/mcp_server/tools/analysis.py +++ b/src/spikelab/mcp_server/tools/analysis.py @@ -451,14 +451,18 @@ async def compute_spike_trig_pop_rate( """Compute spike-triggered population rate and coupling stats and store to workspace.""" ws = _get_workspace(workspace_id) sd = _get_spikedata(ws, namespace) - stPR_filtered, coupling_zero_lag, coupling_max, delays, lags = ( - sd.compute_spike_trig_pop_rate( - window_ms=window_ms, - cutoff_hz=cutoff_hz, - fs=fs, - bin_size=bin_size, - cut_outer=cut_outer, - ) + ( + stPR_filtered, + coupling_zero_lag, + coupling_max, + delays, + lags, + ) = sd.compute_spike_trig_pop_rate( + window_ms=window_ms, + cutoff_hz=cutoff_hz, + fs=fs, + bin_size=bin_size, + cut_outer=cut_outer, ) # Store stPR (U, T) and lags (T,) separately; combine coupling stats as (3, U) coupling_stack = np.stack( @@ -1307,13 +1311,17 @@ async def compute_rate_slice_unit_order( ws = _get_workspace(workspace_id) rss = _get_rateslicestack(ws, namespace, stack_key) frac_active = _get_optional_frac_active(ws, namespace, frac_active_key) - _, unit_ids_in_order, unit_std_indices, unit_peak_times, unit_frac_active = ( - rss.order_units_across_slices( - agg_func, - MIN_RATE_THRESHOLD=min_rate_threshold, - MIN_FRAC_ACTIVE=min_frac_active, - frac_active=frac_active, - ) + ( + _, + unit_ids_in_order, + unit_std_indices, + unit_peak_times, + unit_frac_active, + ) = rss.order_units_across_slices( + agg_func, + MIN_RATE_THRESHOLD=min_rate_threshold, + MIN_FRAC_ACTIVE=min_frac_active, + frac_active=frac_active, ) # Each element is a tuple of two arrays (highly_active, low_active) return { @@ -3208,6 +3216,7 @@ async def train_vae_hippie( ) import os + ckpt_path = os.path.join(output_dir, "vae_best.ckpt") return { "workspace_id": workspace_id, diff --git a/src/spikelab/spikedata/hippie_adapter.py b/src/spikelab/spikedata/hippie_adapter.py index 42f8b0e0..87f0bde7 100644 --- a/src/spikelab/spikedata/hippie_adapter.py +++ b/src/spikelab/spikedata/hippie_adapter.py @@ -39,6 +39,7 @@ def _require_hippie(): # Per-neuron preprocessing helpers # ------------------------------------------------------------------ + def _preprocess_waveform(wave: np.ndarray, target: int = 50) -> np.ndarray: """Resample waveform to target length and min-max normalize to [-1, 1].""" import torch @@ -57,7 +58,9 @@ def _isi_histogram(spike_times: np.ndarray, n_bins: int = _ISI_N_BINS) -> np.nda isis_ms = np.diff(np.sort(spike_times)) * 1000.0 isis_ms = isis_ms[isis_ms > 0] if len(isis_ms) < 2: - return np.full(n_bins, -1.0, dtype=np.float32) # return flat -1 for silent neurons + return np.full( + n_bins, -1.0, dtype=np.float32 + ) # return flat -1 for silent neurons bins = np.logspace(np.log10(_ISI_MIN_MS), np.log10(_ISI_MAX_MS), n_bins + 1) hist, _ = np.histogram(isis_ms, bins=bins, density=True) @@ -106,6 +109,7 @@ def _autocorrelogram( # Public API # ------------------------------------------------------------------ + def extract_features( sd, isi_bins: int = _ISI_N_BINS, @@ -139,10 +143,12 @@ def extract_features( wave_arr = np.stack([_preprocess_waveform(np.asarray(w)) for w in waves]) isi_arr = np.stack([_isi_histogram(t, n_bins=isi_bins) for t in sd.train]) - acg_arr = np.stack([ - _autocorrelogram(t, max_lag_ms=acg_max_lag_ms, n_bins=acg_bins) - for t in sd.train - ]) + acg_arr = np.stack( + [ + _autocorrelogram(t, max_lag_ms=acg_max_lag_ms, n_bins=acg_bins) + for t in sd.train + ] + ) return {"wave": wave_arr, "isi": isi_arr, "acg": acg_arr} @@ -226,6 +232,7 @@ def classify_neurons( # Unconditioned VAE: training + compression # --------------------------------------------------------------------------- + def train_vae_on_spikedata( sd, output_dir: str, @@ -333,7 +340,9 @@ def compress_neurons( result: dict = {"embeddings": embeddings} if run_umap: - result["umap_coords"] = compressor.umap_reduce(embeddings, **(umap_kwargs or {})) + result["umap_coords"] = compressor.umap_reduce( + embeddings, **(umap_kwargs or {}) + ) if run_hdbscan: cluster_input = result.get("umap_coords", embeddings) diff --git a/tests/test_hippie_adapter.py b/tests/test_hippie_adapter.py index 3276601a..6bffa30d 100644 --- a/tests/test_hippie_adapter.py +++ b/tests/test_hippie_adapter.py @@ -24,6 +24,7 @@ # Fixtures # --------------------------------------------------------------------------- + def _make_spike_train(n_spikes=200, duration_s=60.0, seed=0): rng = np.random.default_rng(seed) return np.sort(rng.uniform(0, duration_s, n_spikes)) @@ -50,6 +51,7 @@ def _make_spikedata(n_units=10, seed=0): # Unit tests for preprocessing helpers # --------------------------------------------------------------------------- + class TestPreprocessWaveform: def test_output_shape(self): wave = _make_waveform(82) @@ -109,6 +111,7 @@ def test_empty_train(self): # extract_features # --------------------------------------------------------------------------- + class TestExtractFeatures: def test_shapes(self): sd = _make_spikedata(n_units=8) @@ -125,6 +128,7 @@ def test_dtype(self): def test_no_waveform_raises(self): from spikelab.spikedata import SpikeData + trains = [_make_spike_train(100, seed=i) for i in range(3)] sd = SpikeData(trains, length=60.0) with pytest.raises(ValueError, match="avg_waveform"): @@ -135,13 +139,18 @@ def test_no_waveform_raises(self): # classify_neurons — mocked end-to-end # --------------------------------------------------------------------------- + class TestClassifyNeurons: """Full pipeline test with the HuggingFace download and HIPPIE model mocked out.""" def _make_mock_classifier(self, n_neurons, z_dim=30): mock_clf = MagicMock() - mock_clf.get_embeddings.return_value = np.random.randn(n_neurons, z_dim).astype(np.float32) - mock_clf.umap_reduce.return_value = np.random.randn(n_neurons, 2).astype(np.float32) + mock_clf.get_embeddings.return_value = np.random.randn(n_neurons, z_dim).astype( + np.float32 + ) + mock_clf.umap_reduce.return_value = np.random.randn(n_neurons, 2).astype( + np.float32 + ) mock_clf.hdbscan_cluster.return_value = np.zeros(n_neurons, dtype=np.int32) return mock_clf From 99e8486b22e1df3a4b0e9ceeebc26d1a22061dfe Mon Sep 17 00:00:00 2001 From: JesusGF1 Date: Mon, 27 Apr 2026 09:49:31 -0700 Subject: [PATCH 4/8] Add personal reference for HIPPIE agent prompts Documents copy-paste prompts for running pretrained HIPPIE classifier and unsupervised VAE compression via the SpikeLab MCP server across all three Neuropixels pipeline entry points. Co-Authored-By: Claude Sonnet 4.6 --- docs/source/guides/hippie_agent_prompts.md | 124 +++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 docs/source/guides/hippie_agent_prompts.md diff --git a/docs/source/guides/hippie_agent_prompts.md b/docs/source/guides/hippie_agent_prompts.md new file mode 100644 index 00000000..1d954bd7 --- /dev/null +++ b/docs/source/guides/hippie_agent_prompts.md @@ -0,0 +1,124 @@ +# HIPPIE Agent Prompts for SpikeLab + +Personal reference for running HIPPIE cell-type classification and VAE compression +via an AI agent (e.g., Claude Code) that has access to the SpikeLab MCP server. + +--- + +## 1. Pretrained HIPPIE Classifier + +### Pipeline A — Kilosort output + raw .bin file + +``` +I have Kilosort spike-sorting results at /path/to/kilosort_output/ and the +raw binary recording at /path/to/recording.bin with 384 channels sampled at +30,000 Hz. The electrode geometry is neuropixels. + +Using the SpikeLab MCP server: +1. Load the Kilosort output into a new workspace called "my_session". +2. Extract average waveforms from the raw binary for every unit + (window: 1 ms before to 2 ms after each spike, max 500 spikes per unit). +3. Classify neurons with the pretrained HIPPIE model + (tech_id="neuropixels", run_umap=True, run_hdbscan=True, min_cluster_size=5). +4. Return the cluster label counts and save embeddings + UMAP coordinates + back into the workspace under the "hippie" namespace. +``` + +### Pipeline B — NWB file + +``` +I have an NWB file at /path/to/session.nwb that contains spike times and +pre-sorted units with waveforms stored under processing/ecephys. + +Using the SpikeLab MCP server: +1. Load the NWB file into a workspace called "nwb_session". +2. Classify neurons with the pretrained HIPPIE model + (tech_id="neuropixels", run_umap=True, run_hdbscan=True). +3. Add hippie_cluster, hippie_embedding, hippie_umap_x, hippie_umap_y as + neuron attributes back to the workspace. +4. Show me the cluster label distribution. +``` + +### Pipeline C — Pre-computed waveforms already in SpikeData + +``` +I already have a SpikeData object saved at /path/to/session.pkl that includes +avg_waveform in neuron_attributes. + +Using the SpikeLab MCP server: +1. Load the pickle into workspace "precomputed". +2. Run classify_neurons_hippie with tech_id="neuropixels" and default UMAP/HDBSCAN settings. +3. Write results back to the workspace and show me the number of neurons per cluster. +``` + +--- + +## 2. Unsupervised VAE (train on your own data, no conditioning) + +### Train a new VAE on your recordings + +``` +I have a SpikeData object at workspace "my_session" namespace "sorted" with +avg_waveform in neuron_attributes. + +Using the SpikeLab MCP server: +1. Train an unsupervised multimodal VAE on the neurons in that workspace + using train_vae_hippie: + - output_dir: ./vae_checkpoints/my_session + - z_dim: 30 + - n_epochs: 100 + - batch_size: 256 + - val_fraction: 0.1 +2. Tell me the best validation loss achieved and the path to the saved checkpoint. +``` + +### Compress neurons with a trained checkpoint + +``` +I have a trained VAE checkpoint at ./vae_checkpoints/my_session/vae_best.ckpt +and a SpikeData object in workspace "new_session" namespace "sorted". + +Using the SpikeLab MCP server: +1. Run compress_neurons_hippie with that checkpoint on the workspace. +2. Use run_umap=True and run_hdbscan=True with min_cluster_size=5. +3. Write vae_embedding, vae_umap_x, vae_umap_y, vae_cluster into the workspace. +4. Return the cluster label counts. +``` + +--- + +## 3. Run Both Models and Compare + +``` +I have a SpikeData workspace "comparison_session" namespace "sorted" with +avg_waveform in neuron_attributes (neuropixels recording). + +Using the SpikeLab MCP server, run both classification pipelines and compare: + +Step 1 — Pretrained HIPPIE: + classify_neurons_hippie(tech_id="neuropixels", run_umap=True, run_hdbscan=True, + min_cluster_size=5) + +Step 2 — Unsupervised VAE (train from scratch): + train_vae_hippie(output_dir="./vae_ckpt", z_dim=30, n_epochs=100) + compress_neurons_hippie(checkpoint_path="./vae_ckpt/vae_best.ckpt", + run_umap=True, run_hdbscan=True, min_cluster_size=5) + +Step 3 — Report: + - Number of HIPPIE clusters vs VAE clusters + - Number of noise-labeled neurons (-1) in each + - Suggest which to use for downstream analysis given the cluster coherence +``` + +--- + +## 4. Key Notes + +| Item | Detail | +|------|--------| +| `tech_id` options | `"neuropixels"` (0), `"silicon_probe"` (1), `"juxtacellular"` (2), `"tetrodes"` (3) | +| HIPPIE checkpoint | Auto-downloaded from `Jesusgf23/hippie` on HuggingFace (~293 MB, cached after first use) | +| VAE checkpoint | Saved locally to `output_dir/vae_best.ckpt` | +| `avg_waveform` required | Must be present in `neuron_attributes` before calling either pipeline | +| MCP per-unit limitation | `get_waveform_traces` MCP tool is per-unit; for bulk waveform extraction ask the agent to loop over units or use the Python API directly | +| HIPPIE install | `pip install spikelab[hippie]` before running any of the above | From 24ff76450f4750bfca6750f48a745bf1865d296e Mon Sep 17 00:00:00 2001 From: JesusGF1 Date: Mon, 27 Apr 2026 12:47:58 -0700 Subject: [PATCH 5/8] Fix NWB loader: clear partial trains when pynwb fallback triggers When pynwb appended one or more spike trains before raising an exception, the h5py fallback path would add its own full set of trains on top, causing a train/attribute count mismatch. Clear both lists in the except branch so h5py starts from a clean slate. Co-Authored-By: Claude Sonnet 4.6 --- src/spikelab/data_loaders/data_loaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikelab/data_loaders/data_loaders.py b/src/spikelab/data_loaders/data_loaders.py index e4835512..e9b50042 100644 --- a/src/spikelab/data_loaders/data_loaders.py +++ b/src/spikelab/data_loaders/data_loaders.py @@ -752,6 +752,8 @@ def load_spikedata_from_nwb( f"format or report a bug.", stacklevel=2, ) + trains.clear() + neuron_attributes.clear() ensure_h5py() with h5py.File(filepath, "r") as f: # type: ignore From ec261367cfffad0a15227ca3e09d350a0e7e375d Mon Sep 17 00:00:00 2001 From: TjitsevdM Date: Fri, 29 May 2026 06:38:46 -0700 Subject: [PATCH 6/8] Apply code-review feedback to HIPPIE adapter PR - Fix critical spike-time unit bug in _isi_histogram and _autocorrelogram. SpikeLab convention is milliseconds; the helpers were multiplying by 1000 as if inputs were seconds, pushing ISIs above the 5000 ms upper bin and collapsing ACGs to all zeros. Drop the multiplications and document the unit on both helpers. - Reseat tests/test_hippie_adapter.py fixtures to ms (60_000 instead of 60) and add positive-content regression guards in TestISIHistogram and TestAutocorrelogram that fail if the histograms collapse to flat output (the failure mode the unit bug produced). - extract_features now calls _require_hippie() so users without the [hippie] extra see the install message rather than a raw torch ModuleNotFoundError. - train_vae_hippie probes output_dir writability with mkdirs + a .write_probe round-trip before training starts; fails fast on permission errors instead of after epochs of work. - Extract _store_hippie_result helper to remove the duplicated embedding/UMAP/cluster-storage block between classify_neurons_hippie and compress_neurons_hippie. - Add __all__ to hippie_adapter and re-export from spikedata/__init__.py so import discovery follows the rest of the subpackage. - Add tests/test_dataloaders.py regression test that locks in the NWB partial-pynwb-failure clearing fix: mocks pynwb to populate 2 trains then raise mid-loop, asserts SpikeData.N == 2 (the bug produces 4). --- src/spikelab/mcp_server/tools/analysis.py | 134 +++++++++------------- src/spikelab/spikedata/__init__.py | 1 + src/spikelab/spikedata/hippie_adapter.py | 28 +++-- tests/test_dataloaders.py | 74 ++++++++++++ tests/test_hippie_adapter.py | 32 +++++- 5 files changed, 179 insertions(+), 90 deletions(-) diff --git a/src/spikelab/mcp_server/tools/analysis.py b/src/spikelab/mcp_server/tools/analysis.py index 5375ea64..9182cdd9 100644 --- a/src/spikelab/mcp_server/tools/analysis.py +++ b/src/spikelab/mcp_server/tools/analysis.py @@ -3076,6 +3076,48 @@ async def pairwise_tests( # --------------------------------------------------------------------------- +def _store_hippie_result(ws, sd, workspace_id, namespace, result, prefix): + """Write embeddings / UMAP / cluster labels to neuron_attributes and return summary. + + Shared by ``classify_neurons_hippie`` and ``compress_neurons_hippie``; differs + only in the attribute-name prefix (``"hippie"`` vs ``"vae"``). Re-stores the + SpikeData in the workspace so downstream tools see the new attributes. + """ + sd.set_neuron_attribute(f"{prefix}_embedding", result["embeddings"].tolist()) + added_attrs = [f"{prefix}_embedding"] + if "umap_coords" in result: + sd.set_neuron_attribute( + f"{prefix}_umap_x", result["umap_coords"][:, 0].tolist() + ) + sd.set_neuron_attribute( + f"{prefix}_umap_y", result["umap_coords"][:, 1].tolist() + ) + added_attrs += [f"{prefix}_umap_x", f"{prefix}_umap_y"] + if "cluster_labels" in result: + sd.set_neuron_attribute(f"{prefix}_cluster", result["cluster_labels"].tolist()) + added_attrs.append(f"{prefix}_cluster") + + ws.store(namespace, "spikedata", sd) + + labels = result.get("cluster_labels") + n_clusters = ( + int(np.unique(labels[labels >= 0]).size) if labels is not None else None + ) + n_noise = int((labels < 0).sum()) if labels is not None else None + + return { + "workspace_id": workspace_id, + "namespace": namespace, + "n_neurons": int(result["embeddings"].shape[0]), + "embedding_dim": int(result["embeddings"].shape[1]), + "umap_computed": "umap_coords" in result, + "hdbscan_computed": "cluster_labels" in result, + "n_clusters": n_clusters, + "n_noise_neurons": n_noise, + "neuron_attributes_added": added_attrs, + } + + async def classify_neurons_hippie( workspace_id: str, namespace: str, @@ -3130,45 +3172,7 @@ async def classify_neurons_hippie( cache_dir=cache_dir, ) - # Store results as neuron_attributes - sd.set_neuron_attribute("hippie_embedding", result["embeddings"].tolist()) - if "umap_coords" in result: - sd.set_neuron_attribute("hippie_umap_x", result["umap_coords"][:, 0].tolist()) - sd.set_neuron_attribute("hippie_umap_y", result["umap_coords"][:, 1].tolist()) - if "cluster_labels" in result: - sd.set_neuron_attribute("hippie_cluster", result["cluster_labels"].tolist()) - - # Persist the updated SpikeData - ws.store(namespace, "spikedata", sd) - - n_clusters = ( - int(np.unique(result["cluster_labels"][result["cluster_labels"] >= 0]).size) - if "cluster_labels" in result - else None - ) - n_noise = ( - int((result["cluster_labels"] < 0).sum()) - if "cluster_labels" in result - else None - ) - - added_attrs = ["hippie_embedding"] - if "umap_coords" in result: - added_attrs += ["hippie_umap_x", "hippie_umap_y"] - if "cluster_labels" in result: - added_attrs.append("hippie_cluster") - - return { - "workspace_id": workspace_id, - "namespace": namespace, - "n_neurons": int(result["embeddings"].shape[0]), - "embedding_dim": int(result["embeddings"].shape[1]), - "umap_computed": "umap_coords" in result, - "hdbscan_computed": "cluster_labels" in result, - "n_clusters": n_clusters, - "n_noise_neurons": n_noise, - "neuron_attributes_added": added_attrs, - } + return _store_hippie_result(ws, sd, workspace_id, namespace, result, "hippie") # --------------------------------------------------------------------------- @@ -3199,11 +3203,24 @@ async def train_vae_hippie( Requires avg_waveform in neuron_attributes — run get_waveform_traces first. """ + import os + from ....spikedata.hippie_adapter import train_vae_on_spikedata ws = _get_workspace(workspace_id) sd = _get_spikedata(ws, namespace) + # Fail fast if output_dir is unwritable — VAE training can take hours, so + # surface permission / typo errors before the run starts rather than after. + os.makedirs(output_dir, exist_ok=True) + probe = os.path.join(output_dir, ".write_probe") + try: + with open(probe, "w") as fh: + fh.write("") + os.remove(probe) + except OSError as e: + raise OSError(f"output_dir is not writable: {output_dir!r} ({e})") from e + train_vae_on_spikedata( sd, output_dir=output_dir, @@ -3215,8 +3232,6 @@ async def train_vae_hippie( device=device, ) - import os - ckpt_path = os.path.join(output_dir, "vae_best.ckpt") return { "workspace_id": workspace_id, @@ -3271,39 +3286,4 @@ async def compress_neurons_hippie( device=device, ) - sd.set_neuron_attribute("vae_embedding", result["embeddings"].tolist()) - if "umap_coords" in result: - sd.set_neuron_attribute("vae_umap_x", result["umap_coords"][:, 0].tolist()) - sd.set_neuron_attribute("vae_umap_y", result["umap_coords"][:, 1].tolist()) - if "cluster_labels" in result: - sd.set_neuron_attribute("vae_cluster", result["cluster_labels"].tolist()) - - ws.store(namespace, "spikedata", sd) - - n_clusters = ( - int(np.unique(result["cluster_labels"][result["cluster_labels"] >= 0]).size) - if "cluster_labels" in result - else None - ) - - added_attrs = ["vae_embedding"] - if "umap_coords" in result: - added_attrs += ["vae_umap_x", "vae_umap_y"] - if "cluster_labels" in result: - added_attrs.append("vae_cluster") - - return { - "workspace_id": workspace_id, - "namespace": namespace, - "n_neurons": int(result["embeddings"].shape[0]), - "embedding_dim": int(result["embeddings"].shape[1]), - "umap_computed": "umap_coords" in result, - "hdbscan_computed": "cluster_labels" in result, - "n_clusters": n_clusters, - "n_noise_neurons": ( - int((result["cluster_labels"] < 0).sum()) - if "cluster_labels" in result - else None - ), - "neuron_attributes_added": added_attrs, - } + return _store_hippie_result(ws, sd, workspace_id, namespace, result, "vae") diff --git a/src/spikelab/spikedata/__init__.py b/src/spikelab/spikedata/__init__.py index be273f16..ed9265a1 100644 --- a/src/spikelab/spikedata/__init__.py +++ b/src/spikelab/spikedata/__init__.py @@ -4,3 +4,4 @@ from .ratedata import * from .rateslicestack import * from .spikeslicestack import * +from .hippie_adapter import * # noqa: F401,F403 — module is import-safe without [hippie] diff --git a/src/spikelab/spikedata/hippie_adapter.py b/src/spikelab/spikedata/hippie_adapter.py index 87f0bde7..b1d85d99 100644 --- a/src/spikelab/spikedata/hippie_adapter.py +++ b/src/spikelab/spikedata/hippie_adapter.py @@ -14,6 +14,13 @@ import numpy as np +__all__ = [ + "extract_features", + "classify_neurons", + "train_vae_on_spikedata", + "compress_neurons", +] + # ISI histogram parameters (log-spaced, matching HIPPIE training data convention) _ISI_N_BINS = 100 @@ -54,8 +61,11 @@ def _preprocess_waveform(wave: np.ndarray, target: int = 50) -> np.ndarray: def _isi_histogram(spike_times: np.ndarray, n_bins: int = _ISI_N_BINS) -> np.ndarray: - """Compute a log-spaced ISI histogram, log(x+1)-transformed and min-max normalized.""" - isis_ms = np.diff(np.sort(spike_times)) * 1000.0 + """Compute a log-spaced ISI histogram, log(x+1)-transformed and min-max normalized. + + Spike times must be in milliseconds (SpikeLab convention). + """ + isis_ms = np.diff(np.sort(spike_times)) isis_ms = isis_ms[isis_ms > 0] if len(isis_ms) < 2: return np.full( @@ -79,19 +89,22 @@ def _autocorrelogram( max_lag_ms: float = _ACG_MAX_LAG_MS, n_bins: int = _ACG_N_BINS, ) -> np.ndarray: - """Compute a half-sided autocorrelogram (forward lags only), min-max normalized.""" + """Compute a half-sided autocorrelogram (forward lags only), min-max normalized. + + Spike times must be in milliseconds (SpikeLab convention). + """ if len(spike_times) < 2: return np.zeros(n_bins, dtype=np.float32) - st_ms = np.sort(spike_times) * 1000.0 + st = np.sort(spike_times) bin_edges = np.linspace(0.0, max_lag_ms, n_bins + 1) counts = np.zeros(n_bins, dtype=np.float64) - for i in range(len(st_ms)): - hi = np.searchsorted(st_ms, st_ms[i] + max_lag_ms, side="right") + for i in range(len(st)): + hi = np.searchsorted(st, st[i] + max_lag_ms, side="right") lo = i + 1 if lo < hi: - diffs = st_ms[lo:hi] - st_ms[i] + diffs = st[lo:hi] - st[i] counts += np.histogram(diffs, bins=bin_edges)[0] total = counts.sum() @@ -134,6 +147,7 @@ def extract_features( - "isi": (N, 100) log-transformed, normalized ISI histograms - "acg": (N, 100) normalized autocorrelograms """ + _require_hippie() waves = sd.get_neuron_attribute("avg_waveform") if waves is None or any(w is None for w in waves): raise ValueError( diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index e2cd660d..bba55393 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -5601,6 +5601,80 @@ def __init__(self, *args, **kwargs): self._assert_h5py_fallback_warning(recwarn, OSError) assert sd.N == 2 + def test_pynwb_mid_loop_failure_does_not_leak_into_h5py_fallback( + self, tmp_path, monkeypatch + ): + """ + Regression for the partial-population bug in the pynwb→h5py fallback. + + When pynwb starts successfully and populates ``trains`` / + ``neuron_attributes`` for the first few units, then raises one of + the caught exceptions mid-iteration, the loader must reset both + lists before delegating to the h5py path. Otherwise the h5py + ``extend`` lands on top of the partial pynwb state and + ``SpikeData.N`` ends up as ``partial + h5py_count`` instead of + ``h5py_count``. + + Tests: + (Test Case 1) ``sd.N`` matches the h5py-only count (2), not + the buggy ``partial + h5py_count`` (4). + (Test Case 2) ``sd.train`` contents come from the h5py + layout, confirming pynwb partial data was discarded. + """ + pynwb = pytest.importorskip("pynwb") + from types import SimpleNamespace + + class _MidLoopRaisingDataFrame: + """Iterates K rows successfully, then raises ``KeyError``.""" + + columns: list = [] # no candidate channel columns + + def __init__(self, k_good: int): + self.k_good = k_good + + def itertuples(self): + for i in range(self.k_good): + yield SimpleNamespace( + Index=i, + spike_times=np.array([0.001 * (i + 1)]), + ) + raise KeyError("simulated mid-loop pynwb read failure") + + class _FakeIO: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + def read(self): + return SimpleNamespace( + units=SimpleNamespace( + to_dataframe=lambda: _MidLoopRaisingDataFrame(k_good=2) + ), + electrodes=None, + ) + + monkeypatch.setattr(pynwb, "NWBHDF5IO", _FakeIO) + + path = str(tmp_path / "test.nwb") + self._write_valid_nwb_h5py_layout(path) # 2 units in h5py layout + + with pytest.warns(UserWarning) as recwarn: + sd = loaders.load_spikedata_from_nwb(path) + self._assert_h5py_fallback_warning(recwarn, KeyError) + + # Without the fix, sd.N would be 4 (2 partial pynwb + 2 h5py). + assert sd.N == 2, ( + f"pynwb partial state leaked into h5py fallback: got N={sd.N}, " + f"expected 2" + ) + np.testing.assert_array_equal(sd.train[0], np.array([100.0, 200.0])) + np.testing.assert_array_equal(sd.train[1], np.array([500.0])) + class TestLoadSpikedataFromHdf5RawThresholdedFsHzValidation: """``load_spikedata_from_hdf5_raw_thresholded`` must validate diff --git a/tests/test_hippie_adapter.py b/tests/test_hippie_adapter.py index 6bffa30d..f77482c4 100644 --- a/tests/test_hippie_adapter.py +++ b/tests/test_hippie_adapter.py @@ -19,15 +19,17 @@ classify_neurons, ) - # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- -def _make_spike_train(n_spikes=200, duration_s=60.0, seed=0): +_DURATION_MS = 60_000.0 # 60 s recording, expressed in ms (SpikeLab convention) + + +def _make_spike_train(n_spikes=200, duration_ms=_DURATION_MS, seed=0): rng = np.random.default_rng(seed) - return np.sort(rng.uniform(0, duration_s, n_spikes)) + return np.sort(rng.uniform(0, duration_ms, n_spikes)) def _make_waveform(n=82, seed=0): @@ -40,11 +42,10 @@ def _make_spikedata(n_units=10, seed=0): """Return a minimal SpikeData with avg_waveform in neuron_attributes.""" from spikelab.spikedata import SpikeData - rng = np.random.default_rng(seed) trains = [_make_spike_train(200 + i * 10, seed=seed + i) for i in range(n_units)] waveforms = [_make_waveform(seed=seed + i) for i in range(n_units)] attrs = [{"avg_waveform": w} for w in waveforms] - return SpikeData(trains, length=60.0, neuron_attributes=attrs) + return SpikeData(trains, length=_DURATION_MS, neuron_attributes=attrs) # --------------------------------------------------------------------------- @@ -88,6 +89,16 @@ def test_silent_neuron(self): assert hist.shape == (100,) assert np.isfinite(hist).all() + def test_populated_for_realistic_train(self): + # Regression guard: spike times in ms should land inside the + # [1, 5000] ms log-spaced bin range and produce a histogram with + # genuine structure (more than two distinct values after + # normalization). A unit-conversion bug pushing ISIs out of range + # collapses the histogram to a flat output and trips this check. + st = _make_spike_train(500) + hist = _isi_histogram(st) + assert np.unique(hist).size > 2 + class TestAutocorrelogram: def test_output_shape(self): @@ -106,6 +117,15 @@ def test_empty_train(self): assert acg.shape == (100,) assert (acg == 0).all() + def test_populated_for_realistic_train(self): + # Regression guard: with a high enough rate that some pairs fall + # inside max_lag_ms=100, the ACG must have non-zero counts. A + # unit-conversion bug puts every lag outside the window and + # collapses the ACG to all zeros (or constant). + st = _make_spike_train(2000) + acg = _autocorrelogram(st) + assert np.unique(acg).size > 1 + # --------------------------------------------------------------------------- # extract_features @@ -130,7 +150,7 @@ def test_no_waveform_raises(self): from spikelab.spikedata import SpikeData trains = [_make_spike_train(100, seed=i) for i in range(3)] - sd = SpikeData(trains, length=60.0) + sd = SpikeData(trains, length=_DURATION_MS) with pytest.raises(ValueError, match="avg_waveform"): extract_features(sd) From 15018be138605aa3130f5b40931b5ea9514d6f35 Mon Sep 17 00:00:00 2001 From: TjitsevdM Date: Fri, 29 May 2026 09:40:29 -0700 Subject: [PATCH 7/8] Remove leftover rebase conflict marker in server.py A stray "<<<<<<< HEAD" line was missed during the rebase resolution of the MCP tool dispatch dict, breaking both Python parsing (and so the test suite's import) and the Black formatting check. The conflict content itself was already resolved correctly; only the orphaned opening marker line needed removing. --- src/spikelab/mcp_server/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikelab/mcp_server/server.py b/src/spikelab/mcp_server/server.py index a8750f3d..dbd76055 100644 --- a/src/spikelab/mcp_server/server.py +++ b/src/spikelab/mcp_server/server.py @@ -4379,7 +4379,6 @@ async def _list_tools() -> list[types.Tool]: "load_workspace_item": analysis.load_workspace_item, "merge_workspace": analysis.merge_workspace, "fetch_workspace_item": analysis.fetch_workspace_item, -<<<<<<< HEAD # Shuffling and stack builders "spike_shuffle": analysis.spike_shuffle, "spike_shuffle_stack": analysis.spike_shuffle_stack, From d5acb781fb9b5a986f60d0203f9c43fbd29fad0c Mon Sep 17 00:00:00 2001 From: TjitsevdM Date: Sat, 30 May 2026 00:53:49 -0700 Subject: [PATCH 8/8] Resolve W2 and S4 from PR #120 review W2 (non-zero flat waveforms): remove the `if mx > mn` guard from _preprocess_waveform and add `+ 1e-8` to the divisor so flat inputs (dead channels, clipped recordings) collapse deterministically to all -1.0 instead of passing through as the raw constant. Per Jesus's guidance on the PR thread: dead units cluster together in the noise corner of the latent space rather than scattering on out-of-range values. Adds a regression test asserting a non-zero flat waveform maps to -1.0; the existing all-zeros test still passes. S4 (cross-helper consistency / epsilon intent): with the waveform path now using the same `+ 1e-8` pattern as _isi_histogram and _autocorrelogram, the normalisation idiom is uniform across all three helpers. Add a one-line comment above each `+ 1e-8` site noting that the epsilon mirrors the HIPPIE training-pipeline normalisation, so a future reader doesn't strip it as dead code (confirmed by Jesus on the PR thread). --- src/spikelab/spikedata/hippie_adapter.py | 13 ++++++++++--- tests/test_hippie_adapter.py | 11 +++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/spikelab/spikedata/hippie_adapter.py b/src/spikelab/spikedata/hippie_adapter.py index b1d85d99..8790afb2 100644 --- a/src/spikelab/spikedata/hippie_adapter.py +++ b/src/spikelab/spikedata/hippie_adapter.py @@ -48,15 +48,20 @@ def _require_hippie(): def _preprocess_waveform(wave: np.ndarray, target: int = 50) -> np.ndarray: - """Resample waveform to target length and min-max normalize to [-1, 1].""" + """Resample waveform to target length and min-max normalize to [-1, 1]. + + Flat waveforms (dead channels, clipped recordings) collapse to all -1.0 + via the epsilon-protected divisor, putting those units in a deterministic + noise corner of the HIPPIE latent space. + """ import torch import torch.nn.functional as F t = torch.as_tensor(wave, dtype=torch.float32).view(1, 1, -1) t = F.interpolate(t, size=(target,), mode="linear", align_corners=False).squeeze() mn, mx = t.min().item(), t.max().item() - if mx > mn: - t = (t - mn) / (mx - mn) * 2.0 - 1.0 + # Epsilon matches the HIPPIE training-pipeline normalisation (do not strip). + t = (t - mn) / (mx - mn + 1e-8) * 2.0 - 1.0 return t.numpy().astype(np.float32) @@ -80,6 +85,7 @@ def _isi_histogram(spike_times: np.ndarray, n_bins: int = _ISI_N_BINS) -> np.nda hist = np.log1p(hist) mn, mx = hist.min(), hist.max() if mx > mn: + # Epsilon matches the HIPPIE training-pipeline normalisation (do not strip). hist = (hist - mn) / (mx - mn + 1e-8) * 2.0 - 1.0 return hist @@ -114,6 +120,7 @@ def _autocorrelogram( acg = counts.astype(np.float32) mn, mx = acg.min(), acg.max() if mx > mn: + # Epsilon matches the HIPPIE training-pipeline normalisation (do not strip). acg = (acg - mn) / (mx - mn + 1e-8) * 2.0 - 1.0 return acg diff --git a/tests/test_hippie_adapter.py b/tests/test_hippie_adapter.py index f77482c4..a1bb98f4 100644 --- a/tests/test_hippie_adapter.py +++ b/tests/test_hippie_adapter.py @@ -71,6 +71,17 @@ def test_flat_waveform_does_not_crash(self): assert out.shape == (50,) assert np.isfinite(out).all() + def test_nonzero_flat_waveform_collapses_to_minus_one(self): + # Dead channel / clipped recording: a constant non-zero waveform + # should normalize to all -1.0 (not pass through as the raw constant) + # so those units land in a deterministic noise corner of the latent + # space and cluster together rather than scattering on out-of-range + # values. Regression guard for the W2 fix discussed in PR #120. + wave = np.full(50, -50.0) + out = _preprocess_waveform(wave) + assert out.shape == (50,) + np.testing.assert_allclose(out, -1.0, atol=1e-5) + class TestISIHistogram: def test_output_shape(self):