diff --git a/docs/source/guides/hippie_agent_prompts.md b/docs/source/guides/hippie_agent_prompts.md new file mode 100644 index 00000000..1d954bd7 --- /dev/null +++ b/docs/source/guides/hippie_agent_prompts.md @@ -0,0 +1,124 @@ +# HIPPIE Agent Prompts for SpikeLab + +Personal reference for running HIPPIE cell-type classification and VAE compression +via an AI agent (e.g., Claude Code) that has access to the SpikeLab MCP server. + +--- + +## 1. Pretrained HIPPIE Classifier + +### Pipeline A — Kilosort output + raw .bin file + +``` +I have Kilosort spike-sorting results at /path/to/kilosort_output/ and the +raw binary recording at /path/to/recording.bin with 384 channels sampled at +30,000 Hz. The electrode geometry is neuropixels. + +Using the SpikeLab MCP server: +1. Load the Kilosort output into a new workspace called "my_session". +2. Extract average waveforms from the raw binary for every unit + (window: 1 ms before to 2 ms after each spike, max 500 spikes per unit). +3. Classify neurons with the pretrained HIPPIE model + (tech_id="neuropixels", run_umap=True, run_hdbscan=True, min_cluster_size=5). +4. Return the cluster label counts and save embeddings + UMAP coordinates + back into the workspace under the "hippie" namespace. +``` + +### Pipeline B — NWB file + +``` +I have an NWB file at /path/to/session.nwb that contains spike times and +pre-sorted units with waveforms stored under processing/ecephys. + +Using the SpikeLab MCP server: +1. Load the NWB file into a workspace called "nwb_session". +2. Classify neurons with the pretrained HIPPIE model + (tech_id="neuropixels", run_umap=True, run_hdbscan=True). +3. Add hippie_cluster, hippie_embedding, hippie_umap_x, hippie_umap_y as + neuron attributes back to the workspace. +4. Show me the cluster label distribution. +``` + +### Pipeline C — Pre-computed waveforms already in SpikeData + +``` +I already have a SpikeData object saved at /path/to/session.pkl that includes +avg_waveform in neuron_attributes. + +Using the SpikeLab MCP server: +1. Load the pickle into workspace "precomputed". +2. Run classify_neurons_hippie with tech_id="neuropixels" and default UMAP/HDBSCAN settings. +3. Write results back to the workspace and show me the number of neurons per cluster. +``` + +--- + +## 2. Unsupervised VAE (train on your own data, no conditioning) + +### Train a new VAE on your recordings + +``` +I have a SpikeData object at workspace "my_session" namespace "sorted" with +avg_waveform in neuron_attributes. + +Using the SpikeLab MCP server: +1. Train an unsupervised multimodal VAE on the neurons in that workspace + using train_vae_hippie: + - output_dir: ./vae_checkpoints/my_session + - z_dim: 30 + - n_epochs: 100 + - batch_size: 256 + - val_fraction: 0.1 +2. Tell me the best validation loss achieved and the path to the saved checkpoint. +``` + +### Compress neurons with a trained checkpoint + +``` +I have a trained VAE checkpoint at ./vae_checkpoints/my_session/vae_best.ckpt +and a SpikeData object in workspace "new_session" namespace "sorted". + +Using the SpikeLab MCP server: +1. Run compress_neurons_hippie with that checkpoint on the workspace. +2. Use run_umap=True and run_hdbscan=True with min_cluster_size=5. +3. Write vae_embedding, vae_umap_x, vae_umap_y, vae_cluster into the workspace. +4. Return the cluster label counts. +``` + +--- + +## 3. Run Both Models and Compare + +``` +I have a SpikeData workspace "comparison_session" namespace "sorted" with +avg_waveform in neuron_attributes (neuropixels recording). + +Using the SpikeLab MCP server, run both classification pipelines and compare: + +Step 1 — Pretrained HIPPIE: + classify_neurons_hippie(tech_id="neuropixels", run_umap=True, run_hdbscan=True, + min_cluster_size=5) + +Step 2 — Unsupervised VAE (train from scratch): + train_vae_hippie(output_dir="./vae_ckpt", z_dim=30, n_epochs=100) + compress_neurons_hippie(checkpoint_path="./vae_ckpt/vae_best.ckpt", + run_umap=True, run_hdbscan=True, min_cluster_size=5) + +Step 3 — Report: + - Number of HIPPIE clusters vs VAE clusters + - Number of noise-labeled neurons (-1) in each + - Suggest which to use for downstream analysis given the cluster coherence +``` + +--- + +## 4. Key Notes + +| Item | Detail | +|------|--------| +| `tech_id` options | `"neuropixels"` (0), `"silicon_probe"` (1), `"juxtacellular"` (2), `"tetrodes"` (3) | +| HIPPIE checkpoint | Auto-downloaded from `Jesusgf23/hippie` on HuggingFace (~293 MB, cached after first use) | +| VAE checkpoint | Saved locally to `output_dir/vae_best.ckpt` | +| `avg_waveform` required | Must be present in `neuron_attributes` before calling either pipeline | +| MCP per-unit limitation | `get_waveform_traces` MCP tool is per-unit; for bulk waveform extraction ask the agent to loop over units or use the Python API directly | +| HIPPIE install | `pip install spikelab[hippie]` before running any of the above | diff --git a/docs/source/guides/hippie_changes.md b/docs/source/guides/hippie_changes.md new file mode 100644 index 00000000..fac102c4 --- /dev/null +++ b/docs/source/guides/hippie_changes.md @@ -0,0 +1,375 @@ +# HIPPIE Integration — Codebase Changes + +## Overview + +HIPPIE is added as an **optional dependency** of SpikeLab. +Users who want cell-type classification install it with: + +```bash +pip install "spikelab[hippie]" +``` + +Nothing in the base SpikeLab install changes — all new code is either in new files or behind lazy imports that only run when HIPPIE is actually present. + +--- + +## Repository map + +``` +HIPPIE/ +└── hippie/ + ├── __init__.py ← MODIFIED (+2 lines) + ├── inference.py ← NEW + └── checkpoint.py ← NEW + +SpikeLab/ +├── pyproject.toml ← MODIFIED (+7 lines) +├── src/spikelab/ +│ ├── spikedata/ +│ │ └── hippie_adapter.py ← NEW +│ └── mcp_server/ +│ ├── tools/ +│ │ └── analysis.py ← MODIFIED (+90 lines at EOF) +│ └── server.py ← MODIFIED (+60 lines: schema + dispatch) +├── tests/ +│ └── test_hippie_adapter.py ← NEW +└── docs/source/guides/ + ├── index.rst ← MODIFIED (+1 line) + └── hippie_classification.rst ← NEW +``` + +--- + +## File-by-file changes + +### `HIPPIE/hippie/__init__.py` — modified + +**What changed:** Added one export line. + +```python +# Before (4 lines) +from .multimodal_model import MultiModalCVAE, ... +from .dataloading import ... +from .augmentations import ... +from .backbones import ... + +# After (+1 line) +from .inference import HIPPIEClassifier, TECHNOLOGY_IDS +``` + +**Why:** Allows `from hippie import HIPPIEClassifier` — the clean public API. + +--- + +### `HIPPIE/hippie/checkpoint.py` — new file + +**Purpose:** Load a pretrained checkpoint and return a ready `MultiModalCVAE`. + +``` +infer_model_dims(state_dict) + → reads source_embed.weight / class_embed.weight shapes + → returns (num_sources, num_classes) + +build_model(ckpt_path) + → torch.load(..., weights_only=False) + → strips Lightning "model." prefix from state_dict keys + → calls infer_model_dims to auto-detect architecture + → instantiates MultiModalCVAE with hardcoded inference config: + modalities = {"wave": 50, "isi": 100, "acg": 100} + z_dim = 30 + config = ExperimentConfigs.class_decoder_source_bn_aug_reg() + backbone_base_width = 64 + → model.load_state_dict(sd, strict=False) + → returns model in eval() mode on CPU +``` + +**Key detail:** `strict=False` on `load_state_dict` — the checkpoint may contain +decoder weights that are not needed for inference; this suppresses the error. + +--- + +### `HIPPIE/hippie/inference.py` — new file + +**Purpose:** High-level inference API used by both direct callers and the SpikeLab adapter. + +``` +TECHNOLOGY_IDS: dict + {"neuropixels": 0, "silicon_probe": 1, "juxtacellular": 2, "tetrodes": 3} + +class HIPPIEClassifier: + + from_pretrained(repo_id, filename, device, cache_dir) + → hf_hub_download(repo_id, filename) + → build_model(ckpt_path) + → returns HIPPIEClassifier + + from_checkpoint(checkpoint_path, device) + → build_model(str(path)) + → returns HIPPIEClassifier + + get_embeddings(wave, isi, acg, tech_id, batch_size) → np.ndarray (N, 30) + → accepts tech_id as int or name string + → batches input; calls model.encode({wave, isi, acg}, source_labels=tech_id) + → returns z_mean concatenated across batches + + umap_reduce(embeddings, n_components, n_neighbors, min_dist, metric, random_state) + → static method; lazy-imports umap.UMAP + → returns (N, n_components) float32 coords + + hdbscan_cluster(embeddings, min_cluster_size, min_samples, metric) + → static method; lazy-imports hdbscan.HDBSCAN + → returns (N,) int32 labels; -1 = noise +``` + +**Input contract for `get_embeddings`:** + +| Modality | Shape | Normalization | +|----------|---------|------------------------------------------| +| `wave` | (N, 50) | min-max → [-1, 1] | +| `isi` | (N, 100)| log(x+1), then min-max → [-1, 1] | +| `acg` | (N, 100)| min-max → [-1, 1] | + +--- + +### `SpikeLab/pyproject.toml` — modified + +**What changed:** Added the `hippie` optional extra. + +```toml +[project.optional-dependencies] +# ... existing extras unchanged ... +hippie = [ + "hippie @ git+https://github.com/braingeneers/HIPPIE.git", + "huggingface-hub>=0.20", + "torch>=2.0", + "umap-learn>=0.5.0", + "hdbscan>=0.8", +] +``` + +**Note:** PyTorch with CUDA must still be installed separately if GPU inference is desired. + +--- + +### `SpikeLab/src/spikelab/spikedata/hippie_adapter.py` — new file + +**Purpose:** Bridge between `SpikeData` and the HIPPIE encoder. +Not exported from `spikedata/__init__.py` — imported explicitly to avoid import errors for users without HIPPIE. + +``` +# Guards +_require_hippie() + → tries `import hippie`; raises ImportError with install hint if missing + +# Preprocessing helpers (pure numpy/torch, no HIPPIE dependency) +_preprocess_waveform(wave, target=50) → (50,) float32 + → F.interpolate to target length, min-max → [-1, 1] + +_isi_histogram(spike_times, n_bins=100) → (100,) float32 + → np.diff(sorted spikes) * 1000 → ms + → np.histogram with log-spaced bins [1 ms, 5000 ms] + → log1p transform, then min-max → [-1, 1] + → silent neurons (< 2 spikes) return flat -1 array + +_autocorrelogram(spike_times, max_lag_ms=100, n_bins=100) → (100,) float32 + → forward-looking searchsorted loop (O(N) per lag window) + → normalised to sum=1, then min-max → [-1, 1] + → empty trains return zeros + +# Public API +extract_features(sd, isi_bins, acg_bins, acg_max_lag_ms) → dict + → reads avg_waveform from neuron_attributes (raises ValueError if missing) + → stacks waveforms, ISI histograms, ACGs into (N, bins) arrays + → returns {"wave": (N,50), "isi": (N,100), "acg": (N,100)} + +classify_neurons(sd, repo_id, tech_id, device, run_umap, run_hdbscan, + umap_kwargs, hdbscan_kwargs, batch_size, cache_dir) → dict + → calls _require_hippie() then imports HIPPIEClassifier inside function + → calls extract_features(sd) + → HIPPIEClassifier.from_pretrained(repo_id, device, cache_dir) + → clf.get_embeddings(wave, isi, acg, tech_id, batch_size) + → optionally: clf.umap_reduce(embeddings, **umap_kwargs) + → optionally: clf.hdbscan_cluster(umap_coords or embeddings, **hdbscan_kwargs) + → returns {"embeddings", "umap_coords"?, "cluster_labels"?} +``` + +--- + +### `SpikeLab/src/spikelab/mcp_server/tools/analysis.py` — modified + +**What changed:** ~90 lines appended at end of file. + +``` +async def classify_neurons_hippie( + workspace_id, namespace, + tech_id=0, run_umap=True, run_hdbscan=True, + min_cluster_size=5, umap_n_neighbors=30, umap_min_dist=0.1, + device="cpu", cache_dir=None +) → dict + + Calls: + _get_workspace(workspace_id) + _get_spikedata(ws, namespace) + classify_neurons(sd, ...) ← imported lazily inside function + sd.set_neuron_attribute(...) ← writes hippie_embedding, + hippie_umap_x/y, hippie_cluster + ws.store(namespace, "spikedata", sd) + + Returns JSON summary: + n_neurons, embedding_dim, + umap_computed, hdbscan_computed, + n_clusters, n_noise_neurons, + neuron_attributes_added +``` + +**Import note:** `from ....spikedata.hippie_adapter import classify_neurons` is +inside the function body — the MCP server starts fine without HIPPIE installed +and only fails at call time with a clear error message. + +--- + +### `SpikeLab/src/spikelab/mcp_server/server.py` — modified + +**Two edits:** + +1. **Tool schema** — inserted before the "Workspace management tools" section: + +```python +types.Tool( + name="classify_neurons_hippie", + description="Classify neurons using the pretrained HIPPIE multimodal model ...", + inputSchema={ + "required": ["workspace_id", "namespace"], + "properties": { + workspace_id, namespace, + tech_id (int, default 0), + run_umap (bool, default True), + run_hdbscan (bool, default True), + min_cluster_size (int, default 5), + umap_n_neighbors (int, default 30), + umap_min_dist (float, default 0.1), + device (str, default "cpu"), + cache_dir (str, optional), + } + } +) +``` + +2. **Dispatch entry** — added to `_TOOL_DISPATCH`: + +```python +"classify_neurons_hippie": analysis.classify_neurons_hippie, +``` + +--- + +### `SpikeLab/tests/test_hippie_adapter.py` — new file + +**13 tests across 5 classes**, all skipped automatically if `hippie` is not installed (`pytest.importorskip`). + +``` +TestPreprocessWaveform (3 tests) + ✓ output shape is (50,) + ✓ values in [-1, 1] + ✓ flat waveform (all zeros) does not crash + +TestISIHistogram (3 tests) + ✓ output shape is (100,) + ✓ values in [-1, 1] + ✓ silent neuron (1 spike) does not crash + +TestAutocorrelogram (3 tests) + ✓ output shape is (100,) + ✓ values in [-1, 1] + ✓ empty spike train returns zeros + +TestExtractFeatures (3 tests) + ✓ shapes (N, 50) / (N, 100) / (N, 100) + ✓ dtype is float32 + ✓ missing avg_waveform raises ValueError + +TestClassifyNeurons (4 tests — fully mocked, no download) + @patch("hippie.inference.HIPPIEClassifier") ← correct patch target + ✓ returns embeddings + umap_coords + cluster_labels by default + ✓ run_umap=False, run_hdbscan=False omits those keys + ✓ embedding shape is (N, 30) + ✓ tech_id string is forwarded unchanged to get_embeddings +``` + +--- + +### `SpikeLab/docs/source/guides/hippie_classification.rst` — new file + +Full Sphinx guide covering: +- Installation +- Quick start +- Return value table +- Technology ID table +- Advanced options (custom UMAP/HDBSCAN params, embeddings only, direct HIPPIE API) +- MCP server usage + example agent prompts +- How it works (feature extraction → encoding → UMAP → HDBSCAN) +- Checkpoint info + +### `SpikeLab/docs/source/guides/index.rst` — modified + +Added `hippie_classification` to the toctree (+1 line). + +--- + +## Data flow + +``` +SpikeData + │ + │ avg_waveform (neuron_attributes) + │ spike trains (sd.train) + ▼ +hippie_adapter.extract_features() + │ + ├── _preprocess_waveform() → (N, 50) min-max [-1,1] + ├── _isi_histogram() → (N, 100) log1p + min-max [-1,1] + └── _autocorrelogram() → (N, 100) normalized + min-max [-1,1] + │ + ▼ +HIPPIEClassifier.get_embeddings() + │ + │ model.encode({wave, isi, acg}, source_labels=tech_id) + │ ↳ ResNet18Enc × 3 → fusion_encoder → z_mean (30-D) + │ + ▼ + (N, 30) embeddings + │ + ├──[run_umap=True]──► umap_reduce() → (N, 2) coords + │ cosine metric, n_neighbors=30 + │ + └──[run_hdbscan=True]► hdbscan_cluster() on umap_coords + → (N,) labels (-1 = noise) + │ + ▼ +classify_neurons() returns + {"embeddings": (N,30), "umap_coords": (N,2), "cluster_labels": (N,)} + │ + ▼ [via MCP tool classify_neurons_hippie] +neuron_attributes: + hippie_embedding (N, 30) + hippie_umap_x (N,) + hippie_umap_y (N,) + hippie_cluster (N,) -1 = noise +``` + +--- + +## Bug fixed during audit + +| File | Issue | Fix | +|------|-------|-----| +| `tests/test_hippie_adapter.py` | `@patch("spikelab.spikedata.hippie_adapter.HIPPIEClassifier")` — wrong target; `HIPPIEClassifier` is imported *inside* `classify_neurons` via `from hippie.inference import HIPPIEClassifier`, so the mock must be placed on the source module | Changed to `@patch("hippie.inference.HIPPIEClassifier")` | + +--- + +## What was NOT changed + +- `spikedata/__init__.py` — `hippie_adapter` is intentionally **not** re-exported here; importing it would crash base installs without HIPPIE +- Any existing analysis, data loader, or exporter tool +- HIPPIE training code (`multimodal_model.py`, `dataloading.py`, `augmentations.py`, etc.) +- The HIPPIE `pyproject.toml` diff --git a/docs/source/guides/hippie_classification.rst b/docs/source/guides/hippie_classification.rst new file mode 100644 index 00000000..a7bbe0f7 --- /dev/null +++ b/docs/source/guides/hippie_classification.rst @@ -0,0 +1,495 @@ +.. _hippie_classification: + +Cell-Type Classification with HIPPIE +===================================== + +SpikeLab has optional integration with `HIPPIE`_, a pretrained multimodal +generative model for neuron classification. HIPPIE encodes each neuron's +**waveform**, **interspike-interval distribution**, and **autocorrelogram** +into a shared 30-D latent space, then uses UMAP + HDBSCAN for unsupervised +cell-type discovery. + +.. _HIPPIE: https://huggingface.co/Jesusgf23/hippie + +Installation +------------ + +HIPPIE is an optional dependency — install it alongside SpikeLab:: + + pip install "spikelab[hippie]" + +This pulls in PyTorch, HuggingFace Hub, umap-learn, and hdbscan in addition +to the HIPPIE package itself. PyTorch with CUDA must be installed separately +if GPU inference is desired. + +.. note:: + + Nothing in the base SpikeLab install is affected. The HIPPIE adapter is + never imported unless you explicitly call it. + +Data requirements +----------------- + +HIPPIE requires three features per neuron: + +* **Average waveform** — stored as ``avg_waveform`` in ``neuron_attributes`` +* **Spike trains** — always present in a ``SpikeData`` object +* **Recording technology** — passed as ``tech_id`` at call time + +The waveform is the only thing that may need preparation. The three +pipelines below cover the most common starting points. + +.. _hippie_pipeline_a: + +Pipeline A — Kilosort output + raw ``.bin`` file +------------------------------------------------- + +This is the typical Neuropixels + Kilosort4 workflow. Kilosort gives you +spike times only; waveforms are extracted from the raw voltage trace +afterward. + +.. note:: + + Attaching raw data to a ``SpikeData`` object is currently a Python-only + step — there is no MCP tool for it. Use this path when scripting directly. + +.. code-block:: python + + import numpy as np + from spikelab.data_loaders import load_spikedata_from_kilosort + from spikelab.spikedata.hippie_adapter import classify_neurons + + # 1. Load spike times from Kilosort output directory + sd = load_spikedata_from_kilosort( + folder_path="/path/to/kilosort_output/", + fs_Hz=30000, # Neuropixels default + cluster_info_tsv="cluster_info.tsv", + include_noise=False, + ) + + # 2. Attach the raw voltage recording + # Shape must be (n_channels, n_samples). + # Use np.memmap for large files to avoid loading everything into RAM. + raw = np.memmap( + "/path/to/recording.ap.bin", + dtype="int16", + mode="r", + shape=(385, n_samples), # adjust n_channels and n_samples + ) + sd.raw_data = raw.astype(np.float32) + sd.raw_time = 30.0 # sampling rate in kHz (30 000 Hz) + + # 3. Extract average waveforms for all units in one call. + # store=True writes avg_waveform into neuron_attributes automatically. + sd.get_waveform_traces( + unit=None, # None = all units + ms_before=1.0, + ms_after=2.0, + store=True, + ) + + # 4. Run HIPPIE: embed → UMAP → HDBSCAN + result = classify_neurons( + sd, + tech_id="neuropixels", # or tech_id=0 + run_umap=True, + run_hdbscan=True, + hdbscan_kwargs={"min_cluster_size": 5}, + ) + + # 5. Store results back into neuron_attributes + sd.set_neuron_attribute("hippie_cluster", result["cluster_labels"]) + sd.set_neuron_attribute("hippie_umap_x", result["umap_coords"][:, 0]) + sd.set_neuron_attribute("hippie_umap_y", result["umap_coords"][:, 1]) + sd.set_neuron_attribute("hippie_embedding", result["embeddings"]) + + n_clusters = (result["cluster_labels"] >= 0).sum() + print(f"{sd.N} neurons → {n_clusters} clustered, " + f"{(result['cluster_labels'] < 0).sum()} noise") + +What ``get_waveform_traces`` does in step 3 +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For every unit it finds the peak channel (from ``neuron_to_channel_map``), +extracts a short voltage snippet around each spike, averages them, and +stores the average as ``neuron_attributes[i]["avg_waveform"]``. The adapter +then reads those stored values — no raw data is needed after this point. + + +.. _hippie_pipeline_b: + +Pipeline B — NWB file with raw traces +-------------------------------------- + +NWB files produced by SpikeInterface, the Allen Brain Atlas pipeline, or +similar tools often embed both spike times and raw traces in a single file. +This is the only path that works end-to-end from the **MCP / agent interface**. + +Python +~~~~~~ + +.. code-block:: python + + from spikelab.data_loaders import load_spikedata_from_nwb + from spikelab.spikedata.hippie_adapter import classify_neurons + + sd = load_spikedata_from_nwb("/path/to/recording.nwb") + + # Extract waveforms for all units + sd.get_waveform_traces(unit=None, ms_before=1.0, ms_after=2.0, store=True) + + result = classify_neurons(sd, tech_id="neuropixels") + sd.set_neuron_attribute("hippie_cluster", result["cluster_labels"]) + sd.set_neuron_attribute("hippie_umap_x", result["umap_coords"][:, 0]) + sd.set_neuron_attribute("hippie_umap_y", result["umap_coords"][:, 1]) + sd.set_neuron_attribute("hippie_embedding", result["embeddings"]) + +MCP / agent +~~~~~~~~~~~ + +Give an agent these prompts in order: + +.. code-block:: text + + 1. "Load the NWB file at /path/to/recording.nwb" + + 2. "Extract waveforms for all N units with 1 ms before and 2 ms after the spike" + (the agent will call get_waveform_traces once per unit) + + 3. "Classify the neurons using HIPPIE with tech_id 0 (neuropixels)" + + 4. "How many clusters did HIPPIE find? List cluster IDs and neuron counts." + +.. note:: + + The MCP ``get_waveform_traces`` tool extracts one unit at a time. + For a recording with many units the agent needs to call it N times before + HIPPIE can run. See :ref:`hippie_mcp_gap` below. + + +.. _hippie_pipeline_c: + +Pipeline C — Waveforms already available +----------------------------------------- + +If ``avg_waveform`` is already in ``neuron_attributes`` — e.g. loaded from an +HDF5 workspace, set manually from an upstream pipeline, or computed in a +previous session — skip straight to classification: + +.. code-block:: python + + from spikelab.spikedata.hippie_adapter import classify_neurons + + # sd already has avg_waveform in neuron_attributes + result = classify_neurons(sd, tech_id="neuropixels") + + sd.set_neuron_attribute("hippie_cluster", result["cluster_labels"]) + sd.set_neuron_attribute("hippie_umap_x", result["umap_coords"][:, 0]) + sd.set_neuron_attribute("hippie_umap_y", result["umap_coords"][:, 1]) + sd.set_neuron_attribute("hippie_embedding", result["embeddings"]) + +To check whether waveforms are already present before trying: + +.. code-block:: python + + waves = sd.get_neuron_attribute("avg_waveform") + if waves is None or any(w is None for w in waves): + print("Waveforms missing — run get_waveform_traces first") + else: + print(f"Waveforms ready for {sd.N} units") + +Quick start (waveforms already present) +---------------------------------------- + +.. code-block:: python + + from spikelab.spikedata.hippie_adapter import classify_neurons + + result = classify_neurons( + sd, + tech_id="neuropixels", # or 0, 1, 2, 3 — see Technology IDs below + run_umap=True, + run_hdbscan=True, + ) + + sd.set_neuron_attribute("hippie_cluster", result["cluster_labels"]) + sd.set_neuron_attribute("hippie_umap_x", result["umap_coords"][:, 0]) + sd.set_neuron_attribute("hippie_umap_y", result["umap_coords"][:, 1]) + sd.set_neuron_attribute("hippie_embedding", result["embeddings"]) + +Return values +~~~~~~~~~~~~~ + +.. list-table:: + :header-rows: 1 + :widths: 20 15 65 + + * - Key + - Shape + - Description + * - ``embeddings`` + - ``(N, 30)`` + - Latent z_mean vectors from the HIPPIE encoder + * - ``umap_coords`` + - ``(N, 2)`` + - 2-D UMAP projection (present when ``run_umap=True``) + * - ``cluster_labels`` + - ``(N,)`` + - HDBSCAN cluster IDs; ``-1`` = noise / unclustered + (present when ``run_hdbscan=True``) + +Technology IDs +~~~~~~~~~~~~~~ + +The pretrained checkpoint was trained on recordings from four technology +families. Pass the matching ``tech_id`` for best results: + +.. list-table:: + :header-rows: 1 + :widths: 10 25 + + * - ``tech_id`` + - Technology + * - ``0`` / ``"neuropixels"`` + - Neuropixels probes *(default)* + * - ``1`` / ``"silicon_probe"`` + - Silicon probes (non-Neuropixels) + * - ``2`` / ``"juxtacellular"`` + - Juxtacellular recordings + * - ``3`` / ``"tetrodes"`` + - Tetrode recordings + +Unsupervised VAE compression (no conditioning) +---------------------------------------------- + +If you do not have cell-type or technology labels, or simply want to learn +a compressed representation of *your own dataset* from scratch, the +unconditioned VAE pipeline trains the same ResNet18 + fusion encoder +architecture as the pretrained model but with all conditioning removed. +The only training signal is reconstruction + KL divergence (beta-VAE ELBO, +beta=1) — no class embeddings, no technology embeddings. + +Results are stored as ``vae_embedding``, ``vae_umap_x``, ``vae_umap_y``, +and ``vae_cluster`` in ``neuron_attributes``, keeping them separate from the +pretrained-model attributes (``hippie_embedding`` etc.). + +Train and compress in Python +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + from spikelab.spikedata.hippie_adapter import ( + train_vae_on_spikedata, + compress_neurons, + ) + + # Step 1 — train on your data (requires avg_waveform in neuron_attributes) + compressor = train_vae_on_spikedata( + sd, + output_dir="./my_vae", + z_dim=30, # latent dimensionality + n_epochs=100, + batch_size=256, + device="cpu", # or "cuda" + ) + # Checkpoint saved to ./my_vae/vae_best.ckpt + + # Step 2 — compress (can reuse the returned compressor, or reload later) + result = compress_neurons(sd, compressor, run_umap=True, run_hdbscan=True) + + sd.set_neuron_attribute("vae_cluster", result["cluster_labels"]) + sd.set_neuron_attribute("vae_umap_x", result["umap_coords"][:, 0]) + sd.set_neuron_attribute("vae_umap_y", result["umap_coords"][:, 1]) + sd.set_neuron_attribute("vae_embedding", result["embeddings"]) + +Reload and compress new data later +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + from spikelab.spikedata.hippie_adapter import compress_neurons + + # Load a previously trained checkpoint by path + result = compress_neurons(sd_new, "./my_vae/vae_best.ckpt") + +Use the VAE API directly +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + from hippie.vae import train_vae, VAECompressor + + # Train + compressor = train_vae(wave, isi, acg, output_dir="./my_vae", z_dim=30, n_epochs=100) + + # Or load an existing checkpoint + compressor = VAECompressor.from_checkpoint("./my_vae/vae_best.ckpt") + + # Encode + embeddings = compressor.get_embeddings(wave, isi, acg) # (N, z_dim) + coords = compressor.umap_reduce(embeddings) + labels = compressor.hdbscan_cluster(coords, min_cluster_size=5) + +MCP / agent +~~~~~~~~~~~ + +.. code-block:: text + + 1. "Train an unconditioned VAE on the neurons in namespace 'probe0', + saving to ./my_vae, with 50 epochs." + → calls train_vae_hippie + + 2. "Compress the neurons in 'probe0' using the checkpoint at ./my_vae/vae_best.ckpt." + → calls compress_neurons_hippie + + 3. "How many clusters did the VAE find?" + +How it differs from the pretrained classifier +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. list-table:: + :header-rows: 1 + :widths: 30 35 35 + + * - + - Pretrained HIPPIE (``classify_neurons``) + - Unconditioned VAE (``compress_neurons``) + * - Requires labels + - No (inference only) + - No (train & infer) + * - Requires ``tech_id`` + - Yes + - No + * - Trains on your data + - No + - Yes + * - Learns from 11 datasets + - Yes (pretrained) + - No (your data only) + * - Latent space shaped by + - Cell types + technology + - Reconstruction only + * - Best for + - Known-technology recordings + - Exploratory compression, novel datasets + +Advanced options +---------------- + +Tuning UMAP and HDBSCAN +~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + result = classify_neurons( + sd, + tech_id=0, + umap_kwargs={"n_neighbors": 15, "min_dist": 0.05}, + hdbscan_kwargs={"min_cluster_size": 10, "min_samples": 5}, + ) + +Embeddings only (no clustering) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Useful when you want to inspect the latent space before deciding on +clustering parameters: + +.. code-block:: python + + result = classify_neurons(sd, run_umap=False, run_hdbscan=False) + embeddings = result["embeddings"] # (N, 30) + +Using the HIPPIE API directly +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For full control over preprocessing or batching, import +:class:`hippie.inference.HIPPIEClassifier` directly: + +.. code-block:: python + + from hippie import HIPPIEClassifier + + clf = HIPPIEClassifier.from_pretrained("Jesusgf23/hippie", device="cpu") + + # Inputs must be preprocessed — see hippie_adapter.extract_features() + # for the exact normalization applied to each modality. + embeddings = clf.get_embeddings(wave, isi, acg, tech_id=0) + coords = clf.umap_reduce(embeddings, n_neighbors=30) + labels = clf.hdbscan_cluster(coords, min_cluster_size=5) + + # Load from a local checkpoint instead of HuggingFace + clf2 = HIPPIEClassifier.from_checkpoint("./my_trained_model.ckpt") + +Using via the MCP server +------------------------ + +The ``classify_neurons_hippie`` tool is available in the SpikeLab MCP +server once ``spikelab[hippie]`` is installed. After the tool runs, it +writes ``hippie_embedding``, ``hippie_umap_x``, ``hippie_umap_y``, and +``hippie_cluster`` directly into ``neuron_attributes``, making them +accessible to all downstream tools. + +Example agent prompts +~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: text + + "Classify the neurons in this recording using HIPPIE." + + "Run HIPPIE cell-type classification with tech_id 1 (silicon probe)." + + "Embed the neurons with HIPPIE and cluster with HDBSCAN, minimum cluster size 10." + + "What cell types did HIPPIE find? List the cluster IDs and neuron counts." + + "Plot the HIPPIE UMAP coloured by cluster label." + +.. _hippie_mcp_gap: + +Known limitation: MCP waveform extraction is per-unit +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The current ``get_waveform_traces`` MCP tool extracts waveforms for a +**single unit** per call. For a recording with N neurons, an agent must +call it N times before ``classify_neurons_hippie`` can run. + +Workaround until a bulk ``extract_all_waveforms`` tool is added: + +* Use Pipeline A or C in Python, where ``get_waveform_traces(unit=None)`` + processes all units in one call. +* Or pre-compute waveforms in Python and save the workspace; the agent can + then load it and run ``classify_neurons_hippie`` directly. + +How it works +------------ + +1. **Feature extraction** — For each neuron, SpikeLab computes: + + * *Waveform* (50 samples, min-max normalized to [-1, 1]) + * *ISI histogram* (100 log-spaced bins from 1 ms to 5 s, log(x+1) + transformed, then min-max normalized) + * *Autocorrelogram* (100 bins, 0–100 ms, min-max normalized) + +2. **Encoding** — Three modality-specific ResNet18 encoders project each + neuron's features into a shared 30-D latent space, conditioned on the + recording technology (``tech_id``). + +3. **UMAP** — The 30-D embeddings are projected to 2-D using cosine-distance + UMAP for visualization and clustering. + +4. **HDBSCAN** — Density clusters are found in the 2-D UMAP space. + Neurons that do not belong to any cluster receive label ``-1``. + +Checkpoint +---------- + +The pretrained model (``hippie_techcond_v1.ckpt``) is hosted at +`huggingface.co/Jesusgf23/hippie`_. It is downloaded automatically on +first use and cached locally (HuggingFace default cache, or override with +``cache_dir``). The file is 293 MB; subsequent calls use the local cache. + +.. _huggingface.co/Jesusgf23/hippie: https://huggingface.co/Jesusgf23/hippie + +The model was pretrained on 11 labeled electrophysiology datasets spanning +mouse, rat, and macaque across multiple brain regions and recording +technologies. diff --git a/docs/source/guides/index.rst b/docs/source/guides/index.rst index 9ea4cb90..c1f18635 100644 --- a/docs/source/guides/index.rst +++ b/docs/source/guides/index.rst @@ -20,3 +20,4 @@ In-depth guides covering the main workflows in SpikeLab. mcp_server parallel_computing batch_jobs + hippie_classification diff --git a/pyproject.toml b/pyproject.toml index c974e86e..998bff3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,6 +139,14 @@ ibl = [ # pip install git+https://github.com/int-brain-lab/paper-brain-wide-map.git "ONE-api>=2.0", ] +hippie = [ + # Install HIPPIE from GitHub (or a local editable install for development) + "hippie @ git+https://github.com/braingeneers/HIPPIE.git", + "huggingface-hub>=0.20", + "torch>=2.0", + "umap-learn>=0.5.0", + "hdbscan>=0.8", +] all = [ "pandas>=1.3", "mcp>=0.9.0", diff --git a/src/spikelab/data_loaders/data_loaders.py b/src/spikelab/data_loaders/data_loaders.py index e4835512..e9b50042 100644 --- a/src/spikelab/data_loaders/data_loaders.py +++ b/src/spikelab/data_loaders/data_loaders.py @@ -752,6 +752,8 @@ def load_spikedata_from_nwb( f"format or report a bug.", stacklevel=2, ) + trains.clear() + neuron_attributes.clear() ensure_h5py() with h5py.File(filepath, "r") as f: # type: ignore diff --git a/src/spikelab/mcp_server/server.py b/src/spikelab/mcp_server/server.py index 16e78f26..dbd76055 100644 --- a/src/spikelab/mcp_server/server.py +++ b/src/spikelab/mcp_server/server.py @@ -3007,6 +3007,193 @@ async def _list_tools() -> list[types.Tool]: ] ) + # ----------------------------------------------------------------------- + # HIPPIE cell-type classification (optional — requires spikelab[hippie]) + # ----------------------------------------------------------------------- + tools.extend( + [ + types.Tool( + name="classify_neurons_hippie", + description=( + "Classify neurons using the pretrained HIPPIE multimodal model " + "(requires spikelab[hippie]). " + "Downloads the checkpoint from HuggingFace (Jesusgf23/hippie), " + "encodes each neuron's waveform + ISI + autocorrelogram into a " + "30-D latent space, optionally runs UMAP (2-D projection) and " + "HDBSCAN clustering, then stores the results as neuron_attributes " + "(hippie_embedding, hippie_umap_x, hippie_umap_y, hippie_cluster). " + "Requires avg_waveform in neuron_attributes — run " + "get_waveform_traces first if raw data is available." + ), + inputSchema={ + "type": "object", + "properties": { + **_WS_PROPS, + "tech_id": { + "type": "integer", + "description": ( + "Recording technology index: " + "0=neuropixels (default), 1=silicon_probe, " + "2=juxtacellular, 3=tetrodes" + ), + "default": 0, + }, + "run_umap": { + "type": "boolean", + "description": "Compute 2-D UMAP projection of embeddings (default: true)", + "default": True, + }, + "run_hdbscan": { + "type": "boolean", + "description": "Cluster with HDBSCAN on UMAP coords (default: true)", + "default": True, + }, + "min_cluster_size": { + "type": "integer", + "description": "Minimum neurons per HDBSCAN cluster (default: 5)", + "default": 5, + }, + "umap_n_neighbors": { + "type": "integer", + "description": "UMAP neighbourhood size (default: 30)", + "default": 30, + }, + "umap_min_dist": { + "type": "number", + "description": "UMAP minimum distance between points (default: 0.1)", + "default": 0.1, + }, + "device": { + "type": "string", + "description": "PyTorch device for the HIPPIE encoder: 'cpu' or 'cuda'", + "default": "cpu", + }, + "cache_dir": { + "type": "string", + "description": "Local directory to cache the downloaded checkpoint (optional)", + }, + }, + "required": ["workspace_id", "namespace"], + }, + ), + ] + ) + + # ----------------------------------------------------------------------- + # Unconditioned VAE: training + compression (requires spikelab[hippie]) + # ----------------------------------------------------------------------- + tools.extend( + [ + types.Tool( + name="train_vae_hippie", + description=( + "Train an unconditioned multimodal VAE on a SpikeData object " + "(requires spikelab[hippie]). " + "Uses the same ResNet18 + fusion encoder as the pretrained HIPPIE " + "model but removes all class and technology conditioning — the VAE " + "learns to compress waveform + ISI + autocorrelogram into a latent " + "space using only reconstruction + KL loss. " + "Saves the best checkpoint to output_dir/vae_best.ckpt. " + "Pass that path to compress_neurons_hippie to encode new data. " + "Requires avg_waveform in neuron_attributes — run get_waveform_traces first." + ), + inputSchema={ + "type": "object", + "properties": { + **_WS_PROPS, + "output_dir": { + "type": "string", + "description": "Directory to save the best checkpoint (vae_best.ckpt)", + }, + "z_dim": { + "type": "integer", + "description": "Latent space dimensionality (default 30)", + "default": 30, + }, + "n_epochs": { + "type": "integer", + "description": "Number of training epochs (default 100)", + "default": 100, + }, + "batch_size": { + "type": "integer", + "description": "Minibatch size (default 256)", + "default": 256, + }, + "learning_rate": { + "type": "number", + "description": "AdamW learning rate (default 1e-3)", + "default": 0.001, + }, + "val_fraction": { + "type": "number", + "description": "Fraction of neurons held out for validation (default 0.1)", + "default": 0.1, + }, + "device": { + "type": "string", + "description": "PyTorch device: 'cpu' or 'cuda'", + "default": "cpu", + }, + }, + "required": ["workspace_id", "namespace", "output_dir"], + }, + ), + types.Tool( + name="compress_neurons_hippie", + description=( + "Compress neurons using a trained unconditioned VAE " + "(requires spikelab[hippie]). " + "Encodes all neurons into the VAE latent space, optionally runs " + "UMAP (2-D projection) and HDBSCAN clustering, then stores results " + "as neuron_attributes: vae_embedding, vae_umap_x, vae_umap_y, " + "vae_cluster. Train the model first with train_vae_hippie." + ), + inputSchema={ + "type": "object", + "properties": { + **_WS_PROPS, + "checkpoint_path": { + "type": "string", + "description": "Path to the .ckpt file saved by train_vae_hippie", + }, + "run_umap": { + "type": "boolean", + "description": "Compute 2-D UMAP projection (default true)", + "default": True, + }, + "run_hdbscan": { + "type": "boolean", + "description": "Cluster with HDBSCAN (default true)", + "default": True, + }, + "min_cluster_size": { + "type": "integer", + "description": "Minimum neurons per HDBSCAN cluster (default 5)", + "default": 5, + }, + "umap_n_neighbors": { + "type": "integer", + "description": "UMAP neighbourhood size (default 30)", + "default": 30, + }, + "umap_min_dist": { + "type": "number", + "description": "UMAP minimum distance (default 0.1)", + "default": 0.1, + }, + "device": { + "type": "string", + "description": "PyTorch device: 'cpu' or 'cuda'", + "default": "cpu", + }, + }, + "required": ["workspace_id", "namespace", "checkpoint_path"], + }, + ), + ] + ) + # ----------------------------------------------------------------------- # Workspace management tools # ----------------------------------------------------------------------- @@ -4217,6 +4404,10 @@ async def _list_tools() -> list[types.Tool]: "slice_trend": analysis.slice_trend, "slice_stability": analysis.slice_stability, "pairwise_tests": analysis.pairwise_tests, + # HIPPIE cell-type classification and VAE compression + "classify_neurons_hippie": analysis.classify_neurons_hippie, + "train_vae_hippie": analysis.train_vae_hippie, + "compress_neurons_hippie": analysis.compress_neurons_hippie, # Export tools "export_to_hdf5_raster": exporters.export_to_hdf5_raster, "export_to_hdf5_ragged": exporters.export_to_hdf5_ragged, diff --git a/src/spikelab/mcp_server/tools/analysis.py b/src/spikelab/mcp_server/tools/analysis.py index 9d0dbbac..9182cdd9 100644 --- a/src/spikelab/mcp_server/tools/analysis.py +++ b/src/spikelab/mcp_server/tools/analysis.py @@ -451,14 +451,18 @@ async def compute_spike_trig_pop_rate( """Compute spike-triggered population rate and coupling stats and store to workspace.""" ws = _get_workspace(workspace_id) sd = _get_spikedata(ws, namespace) - stPR_filtered, coupling_zero_lag, coupling_max, delays, lags = ( - sd.compute_spike_trig_pop_rate( - window_ms=window_ms, - cutoff_hz=cutoff_hz, - fs=fs, - bin_size=bin_size, - cut_outer=cut_outer, - ) + ( + stPR_filtered, + coupling_zero_lag, + coupling_max, + delays, + lags, + ) = sd.compute_spike_trig_pop_rate( + window_ms=window_ms, + cutoff_hz=cutoff_hz, + fs=fs, + bin_size=bin_size, + cut_outer=cut_outer, ) # Store stPR (U, T) and lags (T,) separately; combine coupling stats as (3, U) coupling_stack = np.stack( @@ -1307,13 +1311,17 @@ async def compute_rate_slice_unit_order( ws = _get_workspace(workspace_id) rss = _get_rateslicestack(ws, namespace, stack_key) frac_active = _get_optional_frac_active(ws, namespace, frac_active_key) - _, unit_ids_in_order, unit_std_indices, unit_peak_times, unit_frac_active = ( - rss.order_units_across_slices( - agg_func, - MIN_RATE_THRESHOLD=min_rate_threshold, - MIN_FRAC_ACTIVE=min_frac_active, - frac_active=frac_active, - ) + ( + _, + unit_ids_in_order, + unit_std_indices, + unit_peak_times, + unit_frac_active, + ) = rss.order_units_across_slices( + agg_func, + MIN_RATE_THRESHOLD=min_rate_threshold, + MIN_FRAC_ACTIVE=min_frac_active, + frac_active=frac_active, ) # Each element is a tuple of two arrays (highly_active, low_active) return { @@ -3061,3 +3069,221 @@ async def pairwise_tests( if out_key: response["key"] = out_key return response + + +# --------------------------------------------------------------------------- +# HIPPIE cell-type classification (optional — requires spikelab[hippie]) +# --------------------------------------------------------------------------- + + +def _store_hippie_result(ws, sd, workspace_id, namespace, result, prefix): + """Write embeddings / UMAP / cluster labels to neuron_attributes and return summary. + + Shared by ``classify_neurons_hippie`` and ``compress_neurons_hippie``; differs + only in the attribute-name prefix (``"hippie"`` vs ``"vae"``). Re-stores the + SpikeData in the workspace so downstream tools see the new attributes. + """ + sd.set_neuron_attribute(f"{prefix}_embedding", result["embeddings"].tolist()) + added_attrs = [f"{prefix}_embedding"] + if "umap_coords" in result: + sd.set_neuron_attribute( + f"{prefix}_umap_x", result["umap_coords"][:, 0].tolist() + ) + sd.set_neuron_attribute( + f"{prefix}_umap_y", result["umap_coords"][:, 1].tolist() + ) + added_attrs += [f"{prefix}_umap_x", f"{prefix}_umap_y"] + if "cluster_labels" in result: + sd.set_neuron_attribute(f"{prefix}_cluster", result["cluster_labels"].tolist()) + added_attrs.append(f"{prefix}_cluster") + + ws.store(namespace, "spikedata", sd) + + labels = result.get("cluster_labels") + n_clusters = ( + int(np.unique(labels[labels >= 0]).size) if labels is not None else None + ) + n_noise = int((labels < 0).sum()) if labels is not None else None + + return { + "workspace_id": workspace_id, + "namespace": namespace, + "n_neurons": int(result["embeddings"].shape[0]), + "embedding_dim": int(result["embeddings"].shape[1]), + "umap_computed": "umap_coords" in result, + "hdbscan_computed": "cluster_labels" in result, + "n_clusters": n_clusters, + "n_noise_neurons": n_noise, + "neuron_attributes_added": added_attrs, + } + + +async def classify_neurons_hippie( + workspace_id: str, + namespace: str, + tech_id: int = 0, + run_umap: bool = True, + run_hdbscan: bool = True, + min_cluster_size: int = 5, + umap_n_neighbors: int = 30, + umap_min_dist: float = 0.1, + device: str = "cpu", + cache_dir: Optional[str] = None, +) -> Dict[str, Any]: + """Classify neurons using the pretrained HIPPIE model (requires spikelab[hippie]). + + Downloads the HIPPIE checkpoint from HuggingFace, encodes all neurons into + a 30-dimensional latent space, and optionally runs UMAP projection and + HDBSCAN clustering. Results are stored back into the workspace as + neuron_attributes and as a workspace item. + + Requires avg_waveform to be present in neuron_attributes — run + get_waveform_traces first if raw data is available. + + Args: + workspace_id: Workspace ID. + namespace: Recording namespace. + tech_id: Recording technology index (0=neuropixels, 1=silicon_probe, + 2=juxtacellular, 3=tetrodes). + run_umap: Compute 2-D UMAP projection and store coordinates. + run_hdbscan: Cluster with HDBSCAN (-1 = noise). + min_cluster_size: Minimum neurons per HDBSCAN cluster. + umap_n_neighbors: UMAP neighbourhood size. + umap_min_dist: UMAP minimum distance between points. + device: "cuda" or "cpu" for the HIPPIE encoder. + cache_dir: Directory to cache the downloaded checkpoint. + """ + from ....spikedata.hippie_adapter import classify_neurons + + ws = _get_workspace(workspace_id) + sd = _get_spikedata(ws, namespace) + + umap_kwargs = {"n_neighbors": umap_n_neighbors, "min_dist": umap_min_dist} + hdbscan_kwargs = {"min_cluster_size": min_cluster_size} + + result = classify_neurons( + sd, + tech_id=tech_id, + device=device, + run_umap=run_umap, + run_hdbscan=run_hdbscan, + umap_kwargs=umap_kwargs, + hdbscan_kwargs=hdbscan_kwargs, + cache_dir=cache_dir, + ) + + return _store_hippie_result(ws, sd, workspace_id, namespace, result, "hippie") + + +# --------------------------------------------------------------------------- +# Unconditioned VAE: training + compression (requires spikelab[hippie]) +# --------------------------------------------------------------------------- + + +async def train_vae_hippie( + workspace_id: str, + namespace: str, + output_dir: str, + z_dim: int = 30, + n_epochs: int = 100, + batch_size: int = 256, + learning_rate: float = 1e-3, + val_fraction: float = 0.1, + device: str = "cpu", +) -> Dict[str, Any]: + """Train an unconditioned multimodal VAE on a SpikeData object (requires spikelab[hippie]). + + Uses the same ResNet18 + fusion encoder architecture as the pretrained HIPPIE + model but removes all class and technology conditioning. The VAE learns to + compress waveform + ISI + autocorrelogram into a z_dim-dimensional latent + space using only reconstruction + KL loss (beta=1). + + The best checkpoint is saved to output_dir/vae_best.ckpt. Pass this path + to compress_neurons_hippie to encode new data. + + Requires avg_waveform in neuron_attributes — run get_waveform_traces first. + """ + import os + + from ....spikedata.hippie_adapter import train_vae_on_spikedata + + ws = _get_workspace(workspace_id) + sd = _get_spikedata(ws, namespace) + + # Fail fast if output_dir is unwritable — VAE training can take hours, so + # surface permission / typo errors before the run starts rather than after. + os.makedirs(output_dir, exist_ok=True) + probe = os.path.join(output_dir, ".write_probe") + try: + with open(probe, "w") as fh: + fh.write("") + os.remove(probe) + except OSError as e: + raise OSError(f"output_dir is not writable: {output_dir!r} ({e})") from e + + train_vae_on_spikedata( + sd, + output_dir=output_dir, + z_dim=z_dim, + n_epochs=n_epochs, + batch_size=batch_size, + learning_rate=learning_rate, + val_fraction=val_fraction, + device=device, + ) + + ckpt_path = os.path.join(output_dir, "vae_best.ckpt") + return { + "workspace_id": workspace_id, + "namespace": namespace, + "checkpoint_path": ckpt_path, + "z_dim": z_dim, + "n_epochs": n_epochs, + "n_neurons_trained_on": sd.N, + } + + +async def compress_neurons_hippie( + workspace_id: str, + namespace: str, + checkpoint_path: str, + run_umap: bool = True, + run_hdbscan: bool = True, + min_cluster_size: int = 5, + umap_n_neighbors: int = 30, + umap_min_dist: float = 0.1, + device: str = "cpu", +) -> Dict[str, Any]: + """Compress neurons with a trained unconditioned VAE (requires spikelab[hippie]). + + Encodes all neurons into the VAE latent space, optionally runs UMAP and + HDBSCAN, then writes results into neuron_attributes: + vae_embedding, vae_umap_x, vae_umap_y, vae_cluster. + + Args: + workspace_id: Workspace ID. + namespace: Recording namespace. + checkpoint_path: Path to the .ckpt file saved by train_vae_hippie. + run_umap: Compute 2-D UMAP projection. + run_hdbscan: Cluster with HDBSCAN (-1 = noise). + min_cluster_size: Minimum neurons per cluster. + umap_n_neighbors: UMAP neighbourhood size. + umap_min_dist: UMAP minimum distance. + device: "cuda" or "cpu". + """ + from ....spikedata.hippie_adapter import compress_neurons + + ws = _get_workspace(workspace_id) + sd = _get_spikedata(ws, namespace) + + result = compress_neurons( + sd, + compressor=checkpoint_path, + run_umap=run_umap, + run_hdbscan=run_hdbscan, + umap_kwargs={"n_neighbors": umap_n_neighbors, "min_dist": umap_min_dist}, + hdbscan_kwargs={"min_cluster_size": min_cluster_size}, + device=device, + ) + + return _store_hippie_result(ws, sd, workspace_id, namespace, result, "vae") diff --git a/src/spikelab/spikedata/__init__.py b/src/spikelab/spikedata/__init__.py index be273f16..ed9265a1 100644 --- a/src/spikelab/spikedata/__init__.py +++ b/src/spikelab/spikedata/__init__.py @@ -4,3 +4,4 @@ from .ratedata import * from .rateslicestack import * from .spikeslicestack import * +from .hippie_adapter import * # noqa: F401,F403 — module is import-safe without [hippie] diff --git a/src/spikelab/spikedata/hippie_adapter.py b/src/spikelab/spikedata/hippie_adapter.py new file mode 100644 index 00000000..8790afb2 --- /dev/null +++ b/src/spikelab/spikedata/hippie_adapter.py @@ -0,0 +1,374 @@ +"""HIPPIE neuron classification adapter for SpikeData. + +Requires: pip install spikelab[hippie] + +Workflow: + 1. extract_features(sd) — compute wave/ISI/ACG arrays from a SpikeData object + 2. classify_neurons(sd) — one-call pipeline: embed → UMAP → HDBSCAN +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Optional, Union + +import numpy as np + +__all__ = [ + "extract_features", + "classify_neurons", + "train_vae_on_spikedata", + "compress_neurons", +] + + +# ISI histogram parameters (log-spaced, matching HIPPIE training data convention) +_ISI_N_BINS = 100 +_ISI_MIN_MS = 1.0 +_ISI_MAX_MS = 5000.0 + +# ACG parameters +_ACG_N_BINS = 100 +_ACG_MAX_LAG_MS = 100.0 + + +def _require_hippie(): + try: + import hippie # noqa: F401 + except ImportError: + raise ImportError( + "HIPPIE is required for neuron classification. " + "Install it with: pip install spikelab[hippie]" + ) + + +# ------------------------------------------------------------------ +# Per-neuron preprocessing helpers +# ------------------------------------------------------------------ + + +def _preprocess_waveform(wave: np.ndarray, target: int = 50) -> np.ndarray: + """Resample waveform to target length and min-max normalize to [-1, 1]. + + Flat waveforms (dead channels, clipped recordings) collapse to all -1.0 + via the epsilon-protected divisor, putting those units in a deterministic + noise corner of the HIPPIE latent space. + """ + import torch + import torch.nn.functional as F + + t = torch.as_tensor(wave, dtype=torch.float32).view(1, 1, -1) + t = F.interpolate(t, size=(target,), mode="linear", align_corners=False).squeeze() + mn, mx = t.min().item(), t.max().item() + # Epsilon matches the HIPPIE training-pipeline normalisation (do not strip). + t = (t - mn) / (mx - mn + 1e-8) * 2.0 - 1.0 + return t.numpy().astype(np.float32) + + +def _isi_histogram(spike_times: np.ndarray, n_bins: int = _ISI_N_BINS) -> np.ndarray: + """Compute a log-spaced ISI histogram, log(x+1)-transformed and min-max normalized. + + Spike times must be in milliseconds (SpikeLab convention). + """ + isis_ms = np.diff(np.sort(spike_times)) + isis_ms = isis_ms[isis_ms > 0] + if len(isis_ms) < 2: + return np.full( + n_bins, -1.0, dtype=np.float32 + ) # return flat -1 for silent neurons + + bins = np.logspace(np.log10(_ISI_MIN_MS), np.log10(_ISI_MAX_MS), n_bins + 1) + hist, _ = np.histogram(isis_ms, bins=bins, density=True) + hist = hist.astype(np.float32) + + # log(x+1) transform then min-max to [-1, 1] — matches MultiModalEphysDataset + hist = np.log1p(hist) + mn, mx = hist.min(), hist.max() + if mx > mn: + # Epsilon matches the HIPPIE training-pipeline normalisation (do not strip). + hist = (hist - mn) / (mx - mn + 1e-8) * 2.0 - 1.0 + return hist + + +def _autocorrelogram( + spike_times: np.ndarray, + max_lag_ms: float = _ACG_MAX_LAG_MS, + n_bins: int = _ACG_N_BINS, +) -> np.ndarray: + """Compute a half-sided autocorrelogram (forward lags only), min-max normalized. + + Spike times must be in milliseconds (SpikeLab convention). + """ + if len(spike_times) < 2: + return np.zeros(n_bins, dtype=np.float32) + + st = np.sort(spike_times) + bin_edges = np.linspace(0.0, max_lag_ms, n_bins + 1) + counts = np.zeros(n_bins, dtype=np.float64) + + for i in range(len(st)): + hi = np.searchsorted(st, st[i] + max_lag_ms, side="right") + lo = i + 1 + if lo < hi: + diffs = st[lo:hi] - st[i] + counts += np.histogram(diffs, bins=bin_edges)[0] + + total = counts.sum() + if total > 0: + counts /= total + + acg = counts.astype(np.float32) + mn, mx = acg.min(), acg.max() + if mx > mn: + # Epsilon matches the HIPPIE training-pipeline normalisation (do not strip). + acg = (acg - mn) / (mx - mn + 1e-8) * 2.0 - 1.0 + return acg + + +# ------------------------------------------------------------------ +# Public API +# ------------------------------------------------------------------ + + +def extract_features( + sd, + isi_bins: int = _ISI_N_BINS, + acg_bins: int = _ACG_N_BINS, + acg_max_lag_ms: float = _ACG_MAX_LAG_MS, +) -> dict: + """Extract waveform, ISI, and ACG features from a SpikeData object. + + Waveforms are read from ``neuron_attributes["avg_waveform"]``. Call + ``sd.get_waveform_traces()`` first if raw_data is available and + avg_waveform has not yet been computed. + + Args: + sd: SpikeData instance with spike trains and avg_waveform attributes. + isi_bins: Number of log-spaced ISI histogram bins. + acg_bins: Number of autocorrelogram bins. + acg_max_lag_ms: Maximum lag for the autocorrelogram (milliseconds). + + Returns: + dict with keys: + - "wave": (N, 50) min-max normalized waveforms + - "isi": (N, 100) log-transformed, normalized ISI histograms + - "acg": (N, 100) normalized autocorrelograms + """ + _require_hippie() + waves = sd.get_neuron_attribute("avg_waveform") + if waves is None or any(w is None for w in waves): + raise ValueError( + "avg_waveform not found in neuron_attributes. " + "Call sd.get_waveform_traces() first, or set avg_waveform manually." + ) + + wave_arr = np.stack([_preprocess_waveform(np.asarray(w)) for w in waves]) + isi_arr = np.stack([_isi_histogram(t, n_bins=isi_bins) for t in sd.train]) + acg_arr = np.stack( + [ + _autocorrelogram(t, max_lag_ms=acg_max_lag_ms, n_bins=acg_bins) + for t in sd.train + ] + ) + + return {"wave": wave_arr, "isi": isi_arr, "acg": acg_arr} + + +def classify_neurons( + sd, + repo_id: str = "Jesusgf23/hippie", + tech_id: Union[int, str] = 0, + device: str = "cpu", + run_umap: bool = True, + run_hdbscan: bool = True, + umap_kwargs: Optional[dict] = None, + hdbscan_kwargs: Optional[dict] = None, + batch_size: int = 256, + cache_dir: Optional[str] = None, +) -> dict: + """Classify neurons in a SpikeData object using HIPPIE. + + Downloads the pretrained HIPPIE checkpoint, encodes all neurons into + the latent space, and optionally runs UMAP dimensionality reduction + followed by HDBSCAN clustering. + + Args: + sd: SpikeData with spike trains and avg_waveform in neuron_attributes. + repo_id: HuggingFace repository ID for the HIPPIE checkpoint. + tech_id: Recording technology — int index or one of: + "neuropixels" (0), "silicon_probe" (1), + "juxtacellular" (2), "tetrodes" (3). + device: "cuda" or "cpu". + run_umap: Compute 2-D UMAP projection of the embeddings. + run_hdbscan: Cluster with HDBSCAN (applied on UMAP coords when + run_umap=True, otherwise on raw embeddings). + umap_kwargs: Extra keyword arguments for HIPPIEClassifier.umap_reduce(). + hdbscan_kwargs: Extra keyword arguments for HIPPIEClassifier.hdbscan_cluster(). + batch_size: Neurons per forward pass. + cache_dir: Local directory to cache the downloaded checkpoint. + + Returns: + dict with keys: + - "embeddings": (N, 30) latent z_mean vectors + - "umap_coords": (N, 2) UMAP coordinates (present if run_umap=True) + - "cluster_labels":(N,) HDBSCAN labels, -1=noise (present if run_hdbscan=True) + + Example: + >>> from spikelab.spikedata.hippie_adapter import classify_neurons + >>> result = classify_neurons(sd, tech_id="neuropixels") + >>> sd.set_neuron_attribute("hippie_cluster", result["cluster_labels"]) + >>> sd.set_neuron_attribute("hippie_embedding", result["embeddings"]) + """ + _require_hippie() + from hippie.inference import HIPPIEClassifier + + features = extract_features(sd) + + clf = HIPPIEClassifier.from_pretrained( + repo_id=repo_id, device=device, cache_dir=cache_dir + ) + embeddings = clf.get_embeddings( + features["wave"], + features["isi"], + features["acg"], + tech_id=tech_id, + batch_size=batch_size, + ) + + result: dict = {"embeddings": embeddings} + + if run_umap: + result["umap_coords"] = clf.umap_reduce(embeddings, **(umap_kwargs or {})) + + if run_hdbscan: + cluster_input = result.get("umap_coords", embeddings) + result["cluster_labels"] = clf.hdbscan_cluster( + cluster_input, **(hdbscan_kwargs or {}) + ) + + return result + + +# --------------------------------------------------------------------------- +# Unconditioned VAE: training + compression +# --------------------------------------------------------------------------- + + +def train_vae_on_spikedata( + sd, + output_dir: str, + z_dim: int = 30, + n_epochs: int = 100, + batch_size: int = 256, + learning_rate: float = 1e-3, + weight_decay: float = 1e-2, + val_fraction: float = 0.1, + device: str = "cpu", + random_state: int = 42, +) -> "VAECompressor": + """Train an unconditioned multimodal VAE on a SpikeData object. + + Extracts waveform, ISI, and ACG features from sd, then trains a VAE + (no class or technology conditioning) to compress the data. The best + checkpoint is saved to output_dir and a ready VAECompressor is returned. + + Args: + sd: SpikeData with spike trains and avg_waveform in neuron_attributes. + output_dir: Directory to save the best checkpoint (vae_best.ckpt). + z_dim: Latent space dimensionality (default 30). + n_epochs: Training epochs. + batch_size: Minibatch size. + learning_rate: AdamW learning rate. + weight_decay: AdamW weight decay. + val_fraction: Fraction of neurons held out for validation. + device: "cuda" or "cpu". + random_state: Reproducibility seed. + + Returns: + VAECompressor loaded from the best checkpoint. + + Example: + >>> from spikelab.spikedata.hippie_adapter import train_vae_on_spikedata + >>> compressor = train_vae_on_spikedata(sd, output_dir="./my_vae", n_epochs=50) + >>> result = compress_neurons(sd, compressor) + """ + _require_hippie() + from hippie.vae import train_vae + + features = extract_features(sd) + return train_vae( + wave=features["wave"], + isi=features["isi"], + acg=features["acg"], + output_dir=output_dir, + z_dim=z_dim, + n_epochs=n_epochs, + batch_size=batch_size, + learning_rate=learning_rate, + weight_decay=weight_decay, + val_fraction=val_fraction, + device=device, + random_state=random_state, + ) + + +def compress_neurons( + sd, + compressor: Union[str, "VAECompressor"], + run_umap: bool = True, + run_hdbscan: bool = True, + umap_kwargs: Optional[dict] = None, + hdbscan_kwargs: Optional[dict] = None, + batch_size: int = 256, + device: str = "cpu", +) -> dict: + """Compress neurons with a trained unconditioned VAE. + + Args: + sd: SpikeData with spike trains and avg_waveform in neuron_attributes. + compressor: A VAECompressor instance or a path to a checkpoint (.ckpt). + run_umap: Compute 2-D UMAP projection of the embeddings. + run_hdbscan: Cluster with HDBSCAN on UMAP coords (or raw embeddings). + umap_kwargs: Extra kwargs forwarded to VAECompressor.umap_reduce(). + hdbscan_kwargs: Extra kwargs forwarded to VAECompressor.hdbscan_cluster(). + batch_size: Neurons per forward pass. + device: "cuda" or "cpu" (only used when loading from a checkpoint path). + + Returns: + dict with keys: + - "embeddings": (N, z_dim) latent z_mean vectors + - "umap_coords": (N, 2) UMAP coordinates (if run_umap=True) + - "cluster_labels":(N,) HDBSCAN labels, -1=noise (if run_hdbscan=True) + + Example: + >>> result = compress_neurons(sd, "./my_vae/vae_best.ckpt") + >>> sd.set_neuron_attribute("vae_cluster", result["cluster_labels"]) + >>> sd.set_neuron_attribute("vae_umap_x", result["umap_coords"][:, 0]) + >>> sd.set_neuron_attribute("vae_umap_y", result["umap_coords"][:, 1]) + >>> sd.set_neuron_attribute("vae_embedding", result["embeddings"]) + """ + _require_hippie() + from hippie.vae import VAECompressor + + if isinstance(compressor, (str, Path)): + compressor = VAECompressor.from_checkpoint(compressor, device=device) + + features = extract_features(sd) + embeddings = compressor.get_embeddings( + features["wave"], features["isi"], features["acg"], batch_size=batch_size + ) + + result: dict = {"embeddings": embeddings} + + if run_umap: + result["umap_coords"] = compressor.umap_reduce( + embeddings, **(umap_kwargs or {}) + ) + + if run_hdbscan: + cluster_input = result.get("umap_coords", embeddings) + result["cluster_labels"] = compressor.hdbscan_cluster( + cluster_input, **(hdbscan_kwargs or {}) + ) + + return result diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index e2cd660d..bba55393 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -5601,6 +5601,80 @@ def __init__(self, *args, **kwargs): self._assert_h5py_fallback_warning(recwarn, OSError) assert sd.N == 2 + def test_pynwb_mid_loop_failure_does_not_leak_into_h5py_fallback( + self, tmp_path, monkeypatch + ): + """ + Regression for the partial-population bug in the pynwb→h5py fallback. + + When pynwb starts successfully and populates ``trains`` / + ``neuron_attributes`` for the first few units, then raises one of + the caught exceptions mid-iteration, the loader must reset both + lists before delegating to the h5py path. Otherwise the h5py + ``extend`` lands on top of the partial pynwb state and + ``SpikeData.N`` ends up as ``partial + h5py_count`` instead of + ``h5py_count``. + + Tests: + (Test Case 1) ``sd.N`` matches the h5py-only count (2), not + the buggy ``partial + h5py_count`` (4). + (Test Case 2) ``sd.train`` contents come from the h5py + layout, confirming pynwb partial data was discarded. + """ + pynwb = pytest.importorskip("pynwb") + from types import SimpleNamespace + + class _MidLoopRaisingDataFrame: + """Iterates K rows successfully, then raises ``KeyError``.""" + + columns: list = [] # no candidate channel columns + + def __init__(self, k_good: int): + self.k_good = k_good + + def itertuples(self): + for i in range(self.k_good): + yield SimpleNamespace( + Index=i, + spike_times=np.array([0.001 * (i + 1)]), + ) + raise KeyError("simulated mid-loop pynwb read failure") + + class _FakeIO: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + def read(self): + return SimpleNamespace( + units=SimpleNamespace( + to_dataframe=lambda: _MidLoopRaisingDataFrame(k_good=2) + ), + electrodes=None, + ) + + monkeypatch.setattr(pynwb, "NWBHDF5IO", _FakeIO) + + path = str(tmp_path / "test.nwb") + self._write_valid_nwb_h5py_layout(path) # 2 units in h5py layout + + with pytest.warns(UserWarning) as recwarn: + sd = loaders.load_spikedata_from_nwb(path) + self._assert_h5py_fallback_warning(recwarn, KeyError) + + # Without the fix, sd.N would be 4 (2 partial pynwb + 2 h5py). + assert sd.N == 2, ( + f"pynwb partial state leaked into h5py fallback: got N={sd.N}, " + f"expected 2" + ) + np.testing.assert_array_equal(sd.train[0], np.array([100.0, 200.0])) + np.testing.assert_array_equal(sd.train[1], np.array([500.0])) + class TestLoadSpikedataFromHdf5RawThresholdedFsHzValidation: """``load_spikedata_from_hdf5_raw_thresholded`` must validate diff --git a/tests/test_hippie_adapter.py b/tests/test_hippie_adapter.py new file mode 100644 index 00000000..a1bb98f4 --- /dev/null +++ b/tests/test_hippie_adapter.py @@ -0,0 +1,225 @@ +"""Tests for the HIPPIE cell-type classification adapter. + +All tests mock the HuggingFace download and model forward pass so the +293 MB checkpoint is never fetched during CI. +""" + +import numpy as np +import pytest +from unittest.mock import MagicMock, patch + +# Skip every test in this file if hippie is not installed +hippie = pytest.importorskip("hippie", reason="spikelab[hippie] not installed") + +from spikelab.spikedata.hippie_adapter import ( + _isi_histogram, + _autocorrelogram, + _preprocess_waveform, + extract_features, + classify_neurons, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +_DURATION_MS = 60_000.0 # 60 s recording, expressed in ms (SpikeLab convention) + + +def _make_spike_train(n_spikes=200, duration_ms=_DURATION_MS, seed=0): + rng = np.random.default_rng(seed) + return np.sort(rng.uniform(0, duration_ms, n_spikes)) + + +def _make_waveform(n=82, seed=0): + rng = np.random.default_rng(seed) + t = np.linspace(0, 2 * np.pi, n) + return np.sin(t) + rng.normal(0, 0.05, n) + + +def _make_spikedata(n_units=10, seed=0): + """Return a minimal SpikeData with avg_waveform in neuron_attributes.""" + from spikelab.spikedata import SpikeData + + trains = [_make_spike_train(200 + i * 10, seed=seed + i) for i in range(n_units)] + waveforms = [_make_waveform(seed=seed + i) for i in range(n_units)] + attrs = [{"avg_waveform": w} for w in waveforms] + return SpikeData(trains, length=_DURATION_MS, neuron_attributes=attrs) + + +# --------------------------------------------------------------------------- +# Unit tests for preprocessing helpers +# --------------------------------------------------------------------------- + + +class TestPreprocessWaveform: + def test_output_shape(self): + wave = _make_waveform(82) + out = _preprocess_waveform(wave, target=50) + assert out.shape == (50,) + + def test_range(self): + wave = _make_waveform(82) + out = _preprocess_waveform(wave) + assert out.min() >= -1.0 - 1e-5 + assert out.max() <= 1.0 + 1e-5 + + def test_flat_waveform_does_not_crash(self): + wave = np.zeros(50) + out = _preprocess_waveform(wave) + assert out.shape == (50,) + assert np.isfinite(out).all() + + def test_nonzero_flat_waveform_collapses_to_minus_one(self): + # Dead channel / clipped recording: a constant non-zero waveform + # should normalize to all -1.0 (not pass through as the raw constant) + # so those units land in a deterministic noise corner of the latent + # space and cluster together rather than scattering on out-of-range + # values. Regression guard for the W2 fix discussed in PR #120. + wave = np.full(50, -50.0) + out = _preprocess_waveform(wave) + assert out.shape == (50,) + np.testing.assert_allclose(out, -1.0, atol=1e-5) + + +class TestISIHistogram: + def test_output_shape(self): + st = _make_spike_train(200) + hist = _isi_histogram(st, n_bins=100) + assert hist.shape == (100,) + + def test_range(self): + st = _make_spike_train(200) + hist = _isi_histogram(st) + assert hist.min() >= -1.0 - 1e-5 + assert hist.max() <= 1.0 + 1e-5 + + def test_silent_neuron(self): + hist = _isi_histogram(np.array([0.5]), n_bins=100) + assert hist.shape == (100,) + assert np.isfinite(hist).all() + + def test_populated_for_realistic_train(self): + # Regression guard: spike times in ms should land inside the + # [1, 5000] ms log-spaced bin range and produce a histogram with + # genuine structure (more than two distinct values after + # normalization). A unit-conversion bug pushing ISIs out of range + # collapses the histogram to a flat output and trips this check. + st = _make_spike_train(500) + hist = _isi_histogram(st) + assert np.unique(hist).size > 2 + + +class TestAutocorrelogram: + def test_output_shape(self): + st = _make_spike_train(200) + acg = _autocorrelogram(st, n_bins=100) + assert acg.shape == (100,) + + def test_range(self): + st = _make_spike_train(200) + acg = _autocorrelogram(st) + assert acg.min() >= -1.0 - 1e-5 + assert acg.max() <= 1.0 + 1e-5 + + def test_empty_train(self): + acg = _autocorrelogram(np.array([]), n_bins=100) + assert acg.shape == (100,) + assert (acg == 0).all() + + def test_populated_for_realistic_train(self): + # Regression guard: with a high enough rate that some pairs fall + # inside max_lag_ms=100, the ACG must have non-zero counts. A + # unit-conversion bug puts every lag outside the window and + # collapses the ACG to all zeros (or constant). + st = _make_spike_train(2000) + acg = _autocorrelogram(st) + assert np.unique(acg).size > 1 + + +# --------------------------------------------------------------------------- +# extract_features +# --------------------------------------------------------------------------- + + +class TestExtractFeatures: + def test_shapes(self): + sd = _make_spikedata(n_units=8) + feats = extract_features(sd) + assert feats["wave"].shape == (8, 50) + assert feats["isi"].shape == (8, 100) + assert feats["acg"].shape == (8, 100) + + def test_dtype(self): + sd = _make_spikedata(n_units=5) + feats = extract_features(sd) + for arr in feats.values(): + assert arr.dtype == np.float32 + + def test_no_waveform_raises(self): + from spikelab.spikedata import SpikeData + + trains = [_make_spike_train(100, seed=i) for i in range(3)] + sd = SpikeData(trains, length=_DURATION_MS) + with pytest.raises(ValueError, match="avg_waveform"): + extract_features(sd) + + +# --------------------------------------------------------------------------- +# classify_neurons — mocked end-to-end +# --------------------------------------------------------------------------- + + +class TestClassifyNeurons: + """Full pipeline test with the HuggingFace download and HIPPIE model mocked out.""" + + def _make_mock_classifier(self, n_neurons, z_dim=30): + mock_clf = MagicMock() + mock_clf.get_embeddings.return_value = np.random.randn(n_neurons, z_dim).astype( + np.float32 + ) + mock_clf.umap_reduce.return_value = np.random.randn(n_neurons, 2).astype( + np.float32 + ) + mock_clf.hdbscan_cluster.return_value = np.zeros(n_neurons, dtype=np.int32) + return mock_clf + + @patch("hippie.inference.HIPPIEClassifier") + def test_returns_all_keys_by_default(self, MockCls): + n = 10 + MockCls.from_pretrained.return_value = self._make_mock_classifier(n) + sd = _make_spikedata(n_units=n) + result = classify_neurons(sd) + assert "embeddings" in result + assert "umap_coords" in result + assert "cluster_labels" in result + + @patch("hippie.inference.HIPPIEClassifier") + def test_no_umap_no_hdbscan(self, MockCls): + n = 6 + MockCls.from_pretrained.return_value = self._make_mock_classifier(n) + sd = _make_spikedata(n_units=n) + result = classify_neurons(sd, run_umap=False, run_hdbscan=False) + assert "embeddings" in result + assert "umap_coords" not in result + assert "cluster_labels" not in result + + @patch("hippie.inference.HIPPIEClassifier") + def test_embedding_shape(self, MockCls): + n = 12 + MockCls.from_pretrained.return_value = self._make_mock_classifier(n, z_dim=30) + sd = _make_spikedata(n_units=n) + result = classify_neurons(sd) + assert result["embeddings"].shape == (n, 30) + + @patch("hippie.inference.HIPPIEClassifier") + def test_tech_id_string(self, MockCls): + n = 5 + mock_clf = self._make_mock_classifier(n) + MockCls.from_pretrained.return_value = mock_clf + sd = _make_spikedata(n_units=n) + classify_neurons(sd, tech_id="silicon_probe") + mock_clf.get_embeddings.assert_called_once() + call_kwargs = mock_clf.get_embeddings.call_args + assert call_kwargs.kwargs.get("tech_id") == "silicon_probe"