From 23cef31ac371d56149e82d14620da1198c58992a Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 16 Dec 2025 13:19:20 +0100
Subject: [PATCH 1/3] wip: load_analyzer_from_nwb function

---
 .../extractors/extractor_classes.py |   2 +
 .../extractors/nwbextractors.py     | 287 +++++++++++++++++-
 2 files changed, 288 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/extractors/extractor_classes.py b/src/spikeinterface/extractors/extractor_classes.py
index 4f0f586f18..7a91927ec9 100644
--- a/src/spikeinterface/extractors/extractor_classes.py
+++ b/src/spikeinterface/extractors/extractor_classes.py
@@ -35,6 +35,7 @@
     read_nwb_recording,
     read_nwb_sorting,
     read_nwb_timeseries,
+    load_analyzer_from_nwb,
 )
 
 from .cbin_ibl import CompressedBinaryIblExtractor, read_cbin_ibl
@@ -194,6 +195,7 @@
 __all__.extend(
     [
         "read_nwb",  # convenience function for multiple nwb formats
+        "load_analyzer_from_nwb",
         "recording_extractor_full_dict",
         "sorting_extractor_full_dict",
         "event_extractor_full_dict",
diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py
index 976e752a62..069f014970 100644
--- a/src/spikeinterface/extractors/nwbextractors.py
+++ b/src/spikeinterface/extractors/nwbextractors.py
@@ -7,7 +14,7 @@
 import numpy as np
 
 from spikeinterface import get_global_tmp_folder
-from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting, BaseSortingSegment
+from spikeinterface.core import (
+    BaseRecording,
+    BaseRecordingSegment,
+    BaseSorting,
+    BaseSortingSegment,
+    SortingAnalyzer,
+    get_default_analyzer_extension_params,
+)
 from spikeinterface.core.core_tools import define_function_from_class
@@ -1259,6 +1266,7 @@ def _fetch_sorting_segment_info_backend(
 
         # need this for later
         self.units_table = units_table
+        self._file = open_file
 
         return unit_ids, spike_times_data, spike_times_index_data
@@ -1789,3 +1797,280 @@ def read_nwb(file_path, load_recording=True, load_sorting=False, electrical_seri
         outputs = outputs[0]
 
     return outputs
+
+
+def load_analyzer_from_nwb(
+    file_path: str | Path,
+    t_start: float | None = None,
+    sampling_frequency: float | None = None,
+    electrical_series_path: str | None = None,
+    unit_table_path: str | None = None,
+    stream_mode: Literal["fsspec", "remfile", "zarr"] | None = None,
+    stream_cache_path: str | Path | None = None,
+    cache: bool = False,
+    storage_options: dict | None = None,
+    use_pynwb: bool = False,
+    group_name: str | None = None,
+    compute_extra: list[str] | None = ["unit_locations", "correlograms"],
+    compute_extra_params: dict | None = None,
+    verbose: bool = False,
+) -> SortingAnalyzer:
+    """Load a SortingAnalyzer from an NWB file, reusing templates and metrics stored in the units table when available."""
+    import pandas as pd
+
+    from spikeinterface.metrics.template import ComputeTemplateMetrics
+    from spikeinterface.metrics.quality import ComputeQualityMetrics
+
+    # try to read the recording object to attach to the analyzer
+    try:
+        recording = NwbRecordingExtractor(
+            file_path=file_path,
+            electrical_series_path=electrical_series_path,
+            stream_mode=stream_mode,
+            stream_cache_path=stream_cache_path,
+            cache=cache,
+            storage_options=storage_options,
+            use_pynwb=use_pynwb,
+        )
+    except Exception:
+        if verbose:
+            print("Could not load recording, proceeding without it")
+        recording = None
+
+    t_start_tmp = 0 if t_start is None else t_start
+
+    sorting_tmp = NwbSortingExtractor(
+        file_path=file_path,
+        electrical_series_path=electrical_series_path,
+        unit_table_path=unit_table_path,
+        stream_mode=stream_mode,
+        stream_cache_path=stream_cache_path,
+        cache=cache,
+        storage_options=storage_options,
+        use_pynwb=use_pynwb,
+        t_start=t_start_tmp,
+        sampling_frequency=sampling_frequency,
+    )
+
+    if recording is None and t_start is None:
+        # re-estimate t_start from spike times
+        if verbose:
+            print("Re-estimating t_start from spike_times")
+        t_start_new = np.min(sorting_tmp._sorting_segments[0].spike_times_data) - 0.001
+        if verbose:
+            print(f"Found new t_start: {t_start_new} s")
+        sorting = NwbSortingExtractor(
+            file_path=file_path,
+            electrical_series_path=electrical_series_path,
+            unit_table_path=unit_table_path,
+            stream_mode=stream_mode,
+            stream_cache_path=stream_cache_path,
+            cache=cache,
+            storage_options=storage_options,
+            use_pynwb=use_pynwb,
+            t_start=t_start_new,
+            sampling_frequency=sampling_frequency,
+        )
+    else:
+        sorting = sorting_tmp
+
+    if use_pynwb:
+        units = sorting.units_table
+        colnames = units.colnames
+        units = units.to_dataframe(index=True)
+    else:
+        units_dset = sorting._file["units"]
+        units = make_df(units_dset)
+        colnames = units.columns
+
+    electrodes_indices = None
+    if use_pynwb:
+        electrodes_table = sorting._nwbfile.electrodes.to_dataframe(index=True)
+        if "electrodes" in colnames:
+            electrodes_indices = units["electrodes"]
+    else:
+        electrodes_table = make_df(sorting._file["/general/extracellular_ephys/electrodes"])
+        if "electrodes" in colnames:
+            electrodes_indices = units["electrodes"][:]
+
+    if electrodes_indices is not None:
+        # here we assume all groups are the same for each unit, so we just check one.
+        if "group_name" in electrodes_table.columns:
+            group_names = np.array([electrodes_table.iloc[int(ei[0])]["group_name"] for ei in electrodes_indices])
+            if len(np.unique(group_names)) > 1:
+                if group_name is None:
+                    raise Exception(
+                        f"More than one group, use group_name option to select units. Available groups: {np.unique(group_names)}"
+                    )
+                else:
+                    unit_mask = group_names == group_name
+                    if verbose:
+                        print(f"Selecting {sum(unit_mask)} / {len(units)} units from {group_name}")
+                    sorting = sorting.select_units(unit_ids=sorting.unit_ids[unit_mask])
+                    units = units.loc[units.index[unit_mask]]
+                    electrodes_indices = units["electrodes"]
+
+    # TODO: figure out sparsity
+
+    # handle recording if available
+    if recording is not None:
+        # check groups
+        group_names = np.unique(recording.get_channel_groups())
+        if group_name is not None and len(group_names) > 1:
+            recording = recording.split_by("group")[group_name]
+        rec_attributes = None
+    else:
+        recording = None
+        rec_attributes = {}
+
+        # get sliced electrodes table from electrode_indices union
+        electrode_indices_all = []
+        for ei in electrodes_indices:
+            electrode_indices_all.extend(ei)
+        electrode_indices_all = np.sort(np.unique(electrode_indices_all))
+        if verbose:
+            print(f"Found {len(electrode_indices_all)} electrodes")
+        electrodes_table_sliced = electrodes_table.iloc[electrode_indices_all]
+        if "channel_name" in electrodes_table_sliced:
+            channel_ids = electrodes_table_sliced["channel_name"][:]
+        else:
+            channel_ids = electrodes_table_sliced["id"][:]
+        num_samples = [sorting.to_spike_vector()[-1]["sample_index"]]
+        rec_attributes = dict(
+            channel_ids=channel_ids,
+            sampling_frequency=sorting.sampling_frequency,
+            num_channels=len(channel_ids),
+            num_samples=num_samples,
+            is_filtered=True,
+            dtype="float32",
+        )
+        # make a probegroup
+        electrode_colnames = electrodes_table_sliced.columns
+        assert (
+            "rel_x" in electrode_colnames and "rel_y" in electrode_colnames
+        ), "'rel_x' and 'rel_y' should be columns in the electrodes table"
+        locations = np.array([electrodes_table_sliced["rel_x"][:], electrodes_table_sliced["rel_y"][:]]).T
+        probegroup = create_dummy_probegroup_from_locations(locations)
+        rec_attributes["probegroup"] = probegroup
+
+    # instantiate analyzer
+    analyzer = SortingAnalyzer.create_memory(
+        sorting=sorting, recording=recording, sparsity=None, rec_attributes=rec_attributes, return_in_uV=True
+    )
+
+    # templates
+    if "waveform_mean" in units:
+        from spikeinterface.core.analyzer_extension_core import ComputeTemplates, ComputeRandomSpikes
+
+        # instantiate templates
+        analyzer.compute("random_spikes", method="all")
+
+        templates_ext = ComputeTemplates(sorting_analyzer=analyzer)
+        templates_avg_data = np.array([t for t in units["waveform_mean"].values]).astype("float")
+        total_ms = templates_avg_data.shape[1] / analyzer.sampling_frequency * 1000
+        template_params = get_default_analyzer_extension_params("templates")
+        if total_ms != template_params["ms_before"] + template_params["ms_after"]:
+            if verbose:
+                print("Guessing correct template cutouts")
+            template_params["ms_before"] = int(1 / 3 * total_ms)
+            template_params["ms_after"] = total_ms - template_params["ms_before"]
+        template_params["operators"] = ["average", "std"]
+        templates_ext.set_params(**template_params)
+        templates_avg_data = np.array([t for t in units["waveform_mean"].values]).astype("float")
+        templates_ext.data["average"] = templates_avg_data
+        if "waveform_sd" in units:
+            templates_std_data = np.array([t for t in units["waveform_sd"].values]).astype("float")
+        else:
+            templates_std_data = np.zeros_like(templates_avg_data)
+        templates_ext.data["std"] = templates_std_data
+        templates_ext.run_info["run_completed"] = True
+
+        analyzer.extensions["templates"] = templates_ext
+
+    template_metric_columns = ComputeTemplateMetrics.get_metric_columns()
+    quality_metric_columns = ComputeQualityMetrics.get_metric_columns()
+
+    tm = pd.DataFrame(index=sorting.unit_ids)
+    qm = pd.DataFrame(index=sorting.unit_ids)
+
+    for col in units.columns:
+        if col in template_metric_columns:
+            tm.loc[:, col] = units[col].values
+        if col in quality_metric_columns:
+            qm.loc[:, col] = units[col].values
+
+    if len(tm.columns) > 0:
+        if verbose:
+            print("Adding template metrics")
+        tm_ext = ComputeTemplateMetrics(analyzer)
+        tm_ext.data["metrics"] = tm
+        tm_ext.run_info["run_completed"] = True
+        analyzer.extensions["template_metrics"] = tm_ext
+    if len(qm.columns) > 0:
+        if verbose:
+            print("Adding quality metrics")
+        qm_ext = ComputeQualityMetrics(analyzer)
+        qm_ext.data["metrics"] = qm
+        qm_ext.run_info["run_completed"] = True
+        analyzer.extensions["quality_metrics"] = qm_ext
+
+    # compute extra required
+    if compute_extra is not None:
+        if verbose:
+            print(f"Computing extra extensions: {compute_extra}")
+        compute_extra_params = {} if compute_extra_params is None else compute_extra_params
+        analyzer.compute(compute_extra, **compute_extra_params)
+
+    return analyzer
+
+
+def create_dummy_probegroup_from_locations(locations, shape="circle", shape_params={"radius": 1}):
+    """
+    Creates a "dummy" probe based on locations.
+
+    Parameters
+    ----------
+    locations : np.array
+        Array with channel locations (num_channels, 2)
+    shape : str, default: "circle"
+        Electrode shapes
+    shape_params : dict, default: {"radius": 1}
+        Shape parameters
+
+    Returns
+    -------
+    probegroup : ProbeGroup
+        The probe group containing the created probe
+    """
+    from probeinterface import Probe, ProbeGroup
+
+    ndim = locations.shape[1]
+    assert ndim == 2
+    probe = Probe(ndim=2)
+    probe.set_contacts(locations, shapes=shape, shape_params=shape_params)
+    probe.set_device_channel_indices(np.arange(len(probe.contact_positions)))
+    probe.create_auto_shape()
+    probegroup = ProbeGroup()
+    probegroup.add_probe(probe)
+
+    return probegroup
+
+
+def make_df(group):
+    """Makes a pandas DataFrame from an hdf5/zarr NWB group"""
+    import pandas as pd
+
+    colnames = list(group.keys())
+    data = {}
+    for col in colnames:
+        if "_index" in col:
+            continue
+        item = group[col][:]
+        if f"{col}_index" in colnames:
+            item = np.split(item, group[f"{col}_index"][:])[:-1]
+            data[col] = item
+        elif item.ndim > 1:
+            data[col] = [item_flat for item_flat in item]
+        else:
+            data[col] = item
+    df = pd.DataFrame(data=data)
+    df.set_index("id", inplace=True)
+    return df
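
A toy sketch of the ragged-column convention that make_df unpacks above: NWB stores a variable-length column as one flat values array plus a companion "<col>_index" array of end offsets, so splitting at those offsets recovers the per-row lists (the numbers below are made up for illustration):

    import numpy as np

    # flat values plus end offsets, as in an NWB "spike_times" / "spike_times_index" pair
    data = np.array([0.1, 0.2, 0.5, 0.9, 1.0])
    index = np.array([2, 5])  # row 0 -> data[0:2], row 1 -> data[2:5]

    # np.split at the end offsets leaves a trailing empty chunk, hence the [:-1]
    per_row = np.split(data, index)[:-1]
    print(per_row)  # [array([0.1, 0.2]), array([0.5, 0.9, 1. ])]
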
From fcba7d3394dba85af5ffc8a5d7076125611ca071 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 23 Dec 2025 15:58:15 +0100
Subject: [PATCH 2/3] Make dtype/is_filtered optional attrs and apply
 suggestions from code review

---
 src/spikeinterface/core/sortinganalyzer.py |  4 +-
 src/spikeinterface/extractors/__init__.py  |  1 +
 .../extractors/extractor_classes.py        |  2 -
 .../extractors/nwbextractors.py            | 63 ++++++++---------
 4 files changed, 37 insertions(+), 33 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 1870c24e7a..d2a9cbe2df 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -1501,7 +1501,7 @@ def is_sparse(self) -> bool:
         return self.sparsity is not None
 
     def is_filtered(self) -> bool:
-        return self.rec_attributes["is_filtered"]
+        return self.rec_attributes.get("is_filtered", True)
 
     def get_sorting_provenance(self):
         """
@@ -1602,7 +1602,7 @@ def get_sorting_property(self, key) -> np.ndarray:
         return self.sorting.get_property(key)
 
     def get_dtype(self):
-        return self.rec_attributes["dtype"]
+        return self.rec_attributes.get("dtype")
 
     def get_num_units(self) -> int:
         return self.sorting.get_num_units()
diff --git a/src/spikeinterface/extractors/__init__.py b/src/spikeinterface/extractors/__init__.py
index 216c668e0e..000c3832a5 100644
--- a/src/spikeinterface/extractors/__init__.py
+++ b/src/spikeinterface/extractors/__init__.py
@@ -6,6 +6,7 @@
 from .neuropixels_utils import get_neuropixels_channel_groups, get_neuropixels_sample_shifts
 from .neoextractors import get_neo_num_blocks, get_neo_streams
 from .phykilosortextractors import read_kilosort_as_analyzer
+from .nwbextractors import read_nwb_as_analyzer
 
 from warnings import warn
diff --git a/src/spikeinterface/extractors/extractor_classes.py b/src/spikeinterface/extractors/extractor_classes.py
index 7a91927ec9..4f0f586f18 100644
--- a/src/spikeinterface/extractors/extractor_classes.py
+++ b/src/spikeinterface/extractors/extractor_classes.py
@@ -35,7 +35,6 @@
     read_nwb_recording,
     read_nwb_sorting,
     read_nwb_timeseries,
-    load_analyzer_from_nwb,
 )
 
 from .cbin_ibl import CompressedBinaryIblExtractor, read_cbin_ibl
@@ -195,7 +194,6 @@
 __all__.extend(
     [
         "read_nwb",  # convenience function for multiple nwb formats
-        "load_analyzer_from_nwb",
         "recording_extractor_full_dict",
         "sorting_extractor_full_dict",
         "event_extractor_full_dict",
diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py
index 069f014970..2c683088d1 100644
--- a/src/spikeinterface/extractors/nwbextractors.py
+++ b/src/spikeinterface/extractors/nwbextractors.py
@@ -1799,7 +1799,7 @@
     return outputs
 
 
-def load_analyzer_from_nwb(
+def read_nwb_as_analyzer(
     file_path: str | Path,
     t_start: float | None = None,
     sampling_frequency: float | None = None,
@@ -1878,7 +1878,7 @@
         units = units.to_dataframe(index=True)
     else:
         units_dset = sorting._file["units"]
-        units = make_df(units_dset)
+        units = _create_df_from_nwb_table(units_dset)
         colnames = units.columns
 
     electrodes_indices = None
@@ -1887,7 +1887,7 @@
         electrodes_table = sorting._nwbfile.electrodes.to_dataframe(index=True)
         if "electrodes" in colnames:
             electrodes_indices = units["electrodes"]
     else:
-        electrodes_table = make_df(sorting._file["/general/extracellular_ephys/electrodes"])
+        electrodes_table = _create_df_from_nwb_table(sorting._file["/general/extracellular_ephys/electrodes"])
         if "electrodes" in colnames:
             electrodes_indices = units["electrodes"][:]
@@ -1939,8 +1939,6 @@
             sampling_frequency=sorting.sampling_frequency,
             num_channels=len(channel_ids),
             num_samples=num_samples,
-            is_filtered=True,
-            dtype="float32",
         )
         # make a probegroup
         electrode_colnames = electrodes_table_sliced.columns
@@ -1948,7 +1946,7 @@
             "rel_x" in electrode_colnames and "rel_y" in electrode_colnames
         ), "'rel_x' and 'rel_y' should be columns in the electrodes table"
         locations = np.array([electrodes_table_sliced["rel_x"][:], electrodes_table_sliced["rel_y"][:]]).T
-        probegroup = create_dummy_probegroup_from_locations(locations)
+        probegroup = _create_dummy_probegroup_from_locations(locations)
         rec_attributes["probegroup"] = probegroup
 
     # instantiate analyzer
@@ -1960,18 +1958,22 @@
     if "waveform_mean" in units:
         from spikeinterface.core.analyzer_extension_core import ComputeTemplates, ComputeRandomSpikes
 
-        # instantiate templates
+        # compute random spikes, which is a dependency for templates
+        # since we don't know the spike samples, we compute with method 'all'
         analyzer.compute("random_spikes", method="all")
 
+        # instantiate templates
         templates_ext = ComputeTemplates(sorting_analyzer=analyzer)
         templates_avg_data = np.array([t for t in units["waveform_mean"].values]).astype("float")
         total_ms = templates_avg_data.shape[1] / analyzer.sampling_frequency * 1000
-        template_params = get_default_analyzer_extension_params("templates")
-        if total_ms != template_params["ms_before"] + template_params["ms_after"]:
-            if verbose:
-                print("Guessing correct template cutouts")
-            template_params["ms_before"] = int(1 / 3 * total_ms)
-            template_params["ms_after"] = total_ms - template_params["ms_before"]
+        # estimate ms_before and ms_after from the minimum point in the average templates
+        nbefore = np.unravel_index(np.argmin(templates_avg_data), templates_avg_data.shape)[1]
+        ms_before = int(nbefore / analyzer.sampling_frequency * 1000)
+        ms_after = int(total_ms - ms_before)
+        template_params = {}
+        template_params["ms_before"] = ms_before
+        template_params["ms_after"] = ms_after
         template_params["operators"] = ["average", "std"]
         templates_ext.set_params(**template_params)
         templates_avg_data = np.array([t for t in units["waveform_mean"].values]).astype("float")
@@ -1988,29 +1990,32 @@
     template_metric_columns = ComputeTemplateMetrics.get_metric_columns()
     quality_metric_columns = ComputeQualityMetrics.get_metric_columns()
 
-    tm = pd.DataFrame(index=sorting.unit_ids)
-    qm = pd.DataFrame(index=sorting.unit_ids)
+    template_metrics_df = pd.DataFrame(index=sorting.unit_ids)
+    quality_metrics_df = pd.DataFrame(index=sorting.unit_ids)
 
     for col in units.columns:
         if col in template_metric_columns:
-            tm.loc[:, col] = units[col].values
+            template_metrics_df.loc[:, col] = units[col].values
         if col in quality_metric_columns:
-            qm.loc[:, col] = units[col].values
+            quality_metrics_df.loc[:, col] = units[col].values
 
-    if len(tm.columns) > 0:
+    if len(template_metrics_df.columns) > 0:
         if verbose:
             print("Adding template metrics")
-        tm_ext = ComputeTemplateMetrics(analyzer)
-        tm_ext.data["metrics"] = tm
-        tm_ext.run_info["run_completed"] = True
-        analyzer.extensions["template_metrics"] = tm_ext
-    if len(qm.columns) > 0:
+        template_metrics_ext = ComputeTemplateMetrics(analyzer)
+        template_metrics_ext.data["metrics"] = template_metrics_df
+        template_metrics_ext.run_info["run_completed"] = True
+        # cast to correct dtypes
+        template_metrics_ext._cast_metrics()
+        analyzer.extensions["template_metrics"] = template_metrics_ext
-    if len(qm.columns) > 0:
+    if len(quality_metrics_df.columns) > 0:
         if verbose:
             print("Adding quality metrics")
-        qm_ext = ComputeQualityMetrics(analyzer)
-        qm_ext.data["metrics"] = qm
-        qm_ext.run_info["run_completed"] = True
-        analyzer.extensions["quality_metrics"] = qm_ext
+        quality_metrics_ext = ComputeQualityMetrics(analyzer)
+        quality_metrics_ext.data["metrics"] = quality_metrics_df
+        quality_metrics_ext.run_info["run_completed"] = True
+        quality_metrics_ext._cast_metrics()
+        analyzer.extensions["quality_metrics"] = quality_metrics_ext
 
     # compute extra required
     if compute_extra is not None:
@@ -2022,7 +2027,7 @@
     return analyzer
 
 
-def create_dummy_probegroup_from_locations(locations, shape="circle", shape_params={"radius": 1}):
+def _create_dummy_probegroup_from_locations(locations, shape="circle", shape_params={"radius": 1}):
     """
     Creates a "dummy" probe based on locations.
 
@@ -2054,7 +2059,7 @@ def create_dummy_probegroup_from_locations(locations, shape="circle", shape_para
 
     return probegroup
 
-def make_df(group):
+def _create_df_from_nwb_table(group):
     """Makes a pandas DataFrame from an hdf5/zarr NWB group"""
     import pandas as pd
 

From b72a39d8f38d9ed458693d9b9dc45390533f5d9b Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 23 Dec 2025 16:56:48 +0100
Subject: [PATCH 3/3] fix ibl tests

---
 src/spikeinterface/extractors/tests/test_iblextractors.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py
index 972a8e7bb0..56d01e38cf 100644
--- a/src/spikeinterface/extractors/tests/test_iblextractors.py
+++ b/src/spikeinterface/extractors/tests/test_iblextractors.py
@@ -76,8 +76,8 @@ def test_offsets(self):
 
     def test_probe_representation(self):
         probe = self.recording.get_probe()
-        expected_probe_representation = "Probe - 384ch - 1shanks"
-        assert repr(probe) == expected_probe_representation
+        expected_probe_representation = "Probe - 384ch"
+        assert expected_probe_representation in repr(probe)
 
     def test_property_keys(self):
         expected_property_keys = [
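
A minimal usage sketch of the reader introduced by this series, assuming a local NWB file that contains a units table; the file name and group name below are placeholders, not part of the patches:

    from spikeinterface.extractors import read_nwb_as_analyzer

    # hypothetical file and group; adjust to your data
    analyzer = read_nwb_as_analyzer(
        "sub-01_ses-01_ecephys.nwb",
        group_name="probeA",  # required only if the electrodes table has several groups
        compute_extra=["unit_locations", "correlograms"],
        verbose=True,
    )
    print(analyzer)
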