4 changes: 2 additions & 2 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -1501,7 +1501,7 @@ def is_sparse(self) -> bool:
return self.sparsity is not None

def is_filtered(self) -> bool:
return self.rec_attributes["is_filtered"]
return self.rec_attributes.get("is_filtered", True)

def get_sorting_provenance(self):
"""
@@ -1602,7 +1602,7 @@ def get_sorting_property(self, key) -> np.ndarray:
return self.sorting.get_property(key)

def get_dtype(self):
return self.rec_attributes["dtype"]
return self.rec_attributes.get("dtype")

def get_num_units(self) -> int:
return self.sorting.get_num_units()
1 change: 1 addition & 0 deletions src/spikeinterface/extractors/__init__.py
@@ -6,6 +6,7 @@
from .neuropixels_utils import get_neuropixels_channel_groups, get_neuropixels_sample_shifts
from .neoextractors import get_neo_num_blocks, get_neo_streams
from .phykilosortextractors import read_kilosort_as_analyzer
from .nwbextractors import read_nwb_as_analyzer

from warnings import warn

292 changes: 291 additions & 1 deletion src/spikeinterface/extractors/nwbextractors.py
@@ -7,7 +7,14 @@
import numpy as np

from spikeinterface import get_global_tmp_folder
from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting, BaseSortingSegment
from spikeinterface.core import (
BaseRecording,
BaseRecordingSegment,
BaseSorting,
BaseSortingSegment,
SortingAnalyzer,
get_default_analyzer_extension_params,
)
from spikeinterface.core.core_tools import define_function_from_class


@@ -1259,6 +1266,7 @@ def _fetch_sorting_segment_info_backend(

# need this for later
self.units_table = units_table
self._file = open_file

return unit_ids, spike_times_data, spike_times_index_data

@@ -1789,3 +1797,285 @@ def read_nwb(file_path, load_recording=True, load_sorting=False, electrical_seri
outputs = outputs[0]

return outputs


def read_nwb_as_analyzer(
file_path: str | Path,
t_start: float | None = None,
sampling_frequency: float | None = None,
electrical_series_path: str | None = None,
unit_table_path: str | None = None,
stream_mode: Literal["fsspec", "remfile", "zarr"] | None = None,
stream_cache_path: str | Path | None = None,
cache: bool = False,
storage_options: dict | None = None,
use_pynwb: bool = False,
group_name: str | None = None,
compute_extra: list[str] | None = ["unit_locations", "correlograms"],
compute_extra_params: dict | None = None,
verbose: bool = False,
) -> SortingAnalyzer:
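"""
Read an NWB file into an in-memory SortingAnalyzer.

The sorting (and, when possible, the recording) is loaded with the NWB extractors above.
If the units table contains average waveforms ("waveform_mean"), a "templates" extension is
built from them, and unit columns matching known template/quality metric names are imported
as "template_metrics" / "quality_metrics" extensions.

Parameters
----------
file_path : str | Path
Path to the NWB file.
t_start : float | None, default: None
Recording start time in seconds. If None and no recording can be loaded, it is re-estimated
from the earliest spike time.
sampling_frequency : float | None, default: None
Sampling frequency of the sorting, used when it cannot be read from the file.
electrical_series_path : str | None, default: None
Path to the ElectricalSeries used to load the recording.
unit_table_path : str | None, default: None
Path to the units table in the NWB file.
stream_mode : "fsspec" | "remfile" | "zarr" | None, default: None
Streaming mode for remote files, or None for local files.
stream_cache_path : str | Path | None, default: None
Local cache path used when streaming.
cache : bool, default: False
If True, the streamed file is cached locally.
storage_options : dict | None, default: None
Additional options for the zarr backend.
use_pynwb : bool, default: False
If True, pynwb is used to read the file instead of the default backend.
group_name : str | None, default: None
Electrode group to select when units span more than one group.
compute_extra : list[str] | None, default: ["unit_locations", "correlograms"]
Extra extensions to compute once the analyzer is built.
compute_extra_params : dict | None, default: None
Parameters forwarded to analyzer.compute() for the extra extensions.
verbose : bool, default: False
If True, print progress information.

Returns
-------
analyzer : SortingAnalyzer
The in-memory sorting analyzer.
"""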
import pandas as pd
from spikeinterface.metrics.template import ComputeTemplateMetrics
from spikeinterface.metrics.quality import ComputeQualityMetrics

# try to read recording object to get the analyzer
try:
recording = NwbRecordingExtractor(
file_path=file_path,
electrical_series_path=electrical_series_path,
stream_mode=stream_mode,
stream_cache_path=stream_cache_path,
cache=cache,
storage_options=storage_options,
use_pynwb=use_pynwb,
)
except Exception:
if verbose:
print("Could not load recording, proceeding without it")
recording = None

t_start_tmp = 0 if t_start is None else t_start

sorting_tmp = NwbSortingExtractor(
file_path=file_path,
electrical_series_path=electrical_series_path,
unit_table_path=unit_table_path,
stream_mode=stream_mode,
stream_cache_path=stream_cache_path,
cache=cache,
storage_options=storage_options,
use_pynwb=use_pynwb,
t_start=t_start_tmp,
sampling_frequency=sampling_frequency,
)
Comment on lines +1838 to +1851 (Member Author): We could use session_start_time instead.
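Following up on the comment above, a minimal sketch (not part of this PR) of how the session start time could be read with pynwb for a local HDF5 file; converting the datetime into a relative t_start for the sorting is left open:

# hypothetical helper, assuming pynwb is available and file_path points to a local HDF5 NWB file
from pynwb import NWBHDF5IO

def _read_session_start_time(file_path):
    with NWBHDF5IO(str(file_path), mode="r") as io:
        nwbfile = io.read()
        # session_start_time is a datetime; spike times in the file are relative to
        # nwbfile.timestamps_reference_time, so this only anchors an absolute start
        return nwbfile.session_start_time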


if recording is None and t_start is None:
# re-estimate t_start from spike times
if verbose:
print("Re-estimating t_start from spike_times")
t_start_new = np.min(sorting_tmp._sorting_segments[0].spike_times_data) - 0.001
if verbose:
print(f"Found new t_start: {t_start_new} s")
sorting = NwbSortingExtractor(
file_path=file_path,
electrical_series_path=electrical_series_path,
unit_table_path=unit_table_path,
stream_mode=stream_mode,
stream_cache_path=stream_cache_path,
cache=cache,
storage_options=storage_options,
use_pynwb=use_pynwb,
t_start=t_start_new,
sampling_frequency=sampling_frequency,
)
else:
sorting = sorting_tmp

if use_pynwb:
units = sorting.units_table
colnames = units.colnames
units = units.to_dataframe(index=True)
else:
units_dset = sorting._file["units"]
units = _create_df_from_nwb_table(units_dset)
colnames = units.columns

electrodes_indices = None
if use_pynwb:
electrodes_table = sorting._nwbfile.electrodes.to_dataframe(index=True)
if "electrodes" in colnames:
electrodes_indices = units["electrodes"]
else:
electrodes_table = _create_df_from_nwb_table(sorting._file["/general/extracellular_ephys/electrodes"])
if "electrodes" in colnames:
electrodes_indices = units["electrodes"][:]

if electrodes_indices is not None:
# here we assume all groups are the same for each unit, so we just check one.
if "group_name" in electrodes_table.columns:
group_names = np.array([electrodes_table.iloc[int(ei[0])]["group_name"] for ei in electrodes_indices])
if len(np.unique(group_names)) > 1:
if group_name is None:
raise Exception(
f"More than one group, use group_name option to select units. Available groups: {np.unique(group_names)}"
)
else:
unit_mask = group_names == group_name
if verbose:
print(f"Selecting {sum(unit_mask)} / {len(units)} units from {group_name}")
sorting = sorting.select_units(unit_ids=sorting.unit_ids[unit_mask])
units = units.loc[units.index[unit_mask]]
electrodes_indices = units["electrodes"]
Comment on lines +1894 to +1909 (Member Author): We could use the same trick as the "aggregation_key" when instantiating a sorting analyzer from grouped recordings/sortings.
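The "aggregation_key" mechanism mentioned above is not shown in this diff; as a hypothetical illustration of the per-group idea, a sketch that builds one in-memory analyzer per electrode group, using only calls already used in this function (select_units, split_by, SortingAnalyzer.create_memory):

# hypothetical sketch, not part of this PR; assumes a recording was loaded and its
# "group" property matches the group_name values found in the electrodes table
analyzers_by_group = {}
for group in np.unique(group_names):
    unit_mask = group_names == group
    group_sorting = sorting.select_units(unit_ids=sorting.unit_ids[unit_mask])
    group_recording = recording.split_by("group")[group]
    analyzers_by_group[group] = SortingAnalyzer.create_memory(
        sorting=group_sorting,
        recording=group_recording,
        sparsity=None,
        rec_attributes=None,
        return_in_uV=True,
    )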


# TODO: figure out sparsity

# handle recording if available
if recording is not None:
# check groups
group_names = np.unique(recording.get_channel_groups())
if group_name is not None and len(group_names) > 1:
recording = recording.split_by("group")[group_name]
rec_attributes = None
else:
recording = None
rec_attributes = {}

# get sliced electrodes table from electrode_indices union
electrode_indices_all = []
for ei in electrodes_indices:
electrode_indices_all.extend(ei)
electrode_indices_all = np.sort(np.unique(electrode_indices_all))
if verbose:
print(f"Found {len(electrode_indices_all)} electrodes")
electrodes_table_sliced = electrodes_table.iloc[electrode_indices_all]
if "channel_name" in electrodes_table_sliced:
channel_ids = electrodes_table_sliced["channel_name"][:]
else:
channel_ids = electrodes_table_sliced.index.values
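# without a recording, approximate the segment length with the last spike's sample index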
num_samples = [sorting.to_spike_vector()[-1]["sample_index"]]
rec_attributes = dict(
channel_ids=channel_ids,
sampling_frequency=sorting.sampling_frequency,
num_channels=len(channel_ids),
num_samples=num_samples,
)
# make a probegroup
electrode_colnames = electrodes_table_sliced.columns
assert (
"rel_x" in electrode_colnames and "rel_y" in electrode_colnames
), "'rel_x' and 'rel_y' should be columns in the electrode name"
locations = np.array([electrodes_table_sliced["rel_x"][:], electrodes_table_sliced["rel_y"][:]]).T
probegroup = _create_dummy_probegroup_from_locations(locations)
rec_attributes["probegroup"] = probegroup

# instantiate analyzer
analyzer = SortingAnalyzer.create_memory(
sorting=sorting, recording=recording, sparsity=None, rec_attributes=rec_attributes, return_in_uV=True
)

# templates
if "waveform_mean" in units:
from spikeinterface.core.analyzer_extension_core import ComputeTemplates, ComputeRandomSpikes

# compute random spikes, which is a dependency for templates
# since we don't know the spike samples, we compute with method 'all'
analyzer.compute("random_spikes", method="all")

# instantiate templates
templates_ext = ComputeTemplates(sorting_analyzer=analyzer)
templates_avg_data = np.array([t for t in units["waveform_mean"].values]).astype("float")
total_ms = templates_avg_data.shape[1] / analyzer.sampling_frequency * 1000
# estimate ms_before and ms_after from minimum point in the average template
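# assumes waveform_mean stacks to (num_units, num_samples) or (num_units, num_samples, num_channels),
# so axis 1 is the sample axis and the sample index of the global trough gives nbefore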
nbefore = np.unravel_index(np.argmin(templates_avg_data), templates_avg_data.shape)[1]
ms_before = int(nbefore / analyzer.sampling_frequency * 1000)
ms_after = int(total_ms - ms_before)
template_params = {}
template_params["ms_before"] = ms_before
template_params["ms_after"] = ms_after
template_params["operators"] = ["average", "std"]
templates_ext.set_params(**template_params)
templates_ext.data["average"] = templates_avg_data
if "waveforms_sd" in units:
templates_std_data = np.array([t for t in units["waveform_sd"].values]).astype("float")
else:
templates_std_data = np.zeros_like(templates_avg_data)
templates_ext.data["std"] = templates_std_data
templates_ext.run_info["run_completed"] = True

analyzer.extensions["templates"] = templates_ext

template_metric_columns = ComputeTemplateMetrics.get_metric_columns()
quality_metric_columns = ComputeQualityMetrics.get_metric_columns()

template_metrics_df = pd.DataFrame(index=sorting.unit_ids)
quality_metric_df = pd.DataFrame(index=sorting.unit_ids)

for col in units.columns:
if col in template_metric_columns:
template_metrics_df.loc[:, col] = units[col].values
if col in quality_metric_columns:
quality_metric_df.loc[:, col] = units[col].values

if len(template_metrics_df.columns) > 0:
if verbose:
print("Adding template metrics")
template_metrics_ext = ComputeTemplateMetrics(analyzer)
template_metrics_ext.data["metrics"] = template_metrics_df
template_metrics_ext.run_info["run_completed"] = True
# cast to correct dtypes
template_metrics_ext._cast_metrics()
analyzer.extensions["template_metrics"] = template_metrics_ext
if len(quality_metric_df.columns) > 0:
if verbose:
print("Adding quality metrics")
quality_metrics_ext = ComputeQualityMetrics(analyzer)
quality_metrics_ext.data["metrics"] = quality_metric_df
quality_metrics_ext.run_info["run_completed"] = True
quality_metrics_ext._cast_metrics()
analyzer.extensions["quality_metrics"] = quality_metrics_ext

# compute extra required
if compute_extra is not None:
if verbose:
print(f"Computing extra extensions: {compute_extra}")
compute_extra_params = {} if compute_extra_params is None else compute_extra_params
analyzer.compute(compute_extra, **compute_extra_params)

return analyzer
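A hypothetical usage sketch (the file path is made up; which extensions end up loaded depends on what the units table contains):

# hypothetical usage, not part of this PR
analyzer = read_nwb_as_analyzer(
    "sub-01_ses-01_ecephys.nwb",
    compute_extra=["unit_locations", "correlograms"],
    verbose=True,
)
print(analyzer)
print(list(analyzer.extensions))  # e.g. templates, template_metrics, quality_metrics, plus the extras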


def _create_dummy_probegroup_from_locations(locations, shape="circle", shape_params={"radius": 1}):
"""
Creates a "dummy" probe based on locations.

Parameters
----------
locations : np.array
Array with channel locations (num_channels, 2); only 2D locations are currently supported
shape : str, default: "circle"
Electrode shapes
shape_params : dict, default: {"radius": 1}
Shape parameters

Returns
-------
probegroup : ProbeGroup
A probe group containing the created dummy probe
"""
from probeinterface import Probe, ProbeGroup

ndim = locations.shape[1]
assert ndim == 2, "only 2D locations are supported"
probe = Probe(ndim=2)
probe.set_contacts(locations, shapes=shape, shape_params=shape_params)
probe.set_device_channel_indices(np.arange(len(probe.contact_positions)))
probe.create_auto_shape()
probegroup = ProbeGroup()
probegroup.add_probe(probe)

return probegroup


def _create_df_from_nwb_table(group):
"""Makes pandas DataFrame from hdf5/zarr NWB group"""
import pandas as pd

colnames = list(group.keys())
data = {}
for col in colnames:
if "_index" in col:
continue
item = group[col][:]
if f"{col}_index" in colnames:
item = np.split(item, group[f"{col}_index"][:])[:-1]
data[col] = item
elif item.ndim > 1:
data[col] = [item_flat for item_flat in item]
else:
data[col] = item
df = pd.DataFrame(data=data)
df.set_index("id", inplace=True)
return df
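For reference, a small self-contained sketch (made-up numbers) of the ragged-column convention this helper relies on:

import numpy as np

# a ragged NWB column is stored as flat data plus cumulative end offsets in "<col>_index"
spike_times = np.array([0.01, 0.05, 0.20, 0.30, 0.31, 0.90])
spike_times_index = np.array([2, 5, 6])  # row 0 -> first 2 values, row 1 -> next 3, row 2 -> last 1

per_row = np.split(spike_times, spike_times_index)[:-1]  # drop the trailing empty chunk
# [array([0.01, 0.05]), array([0.2, 0.3, 0.31]), array([0.9])]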
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/tests/test_iblextractors.py
@@ -76,8 +76,8 @@ def test_offsets(self):

def test_probe_representation(self):
probe = self.recording.get_probe()
expected_probe_representation = "Probe - 384ch - 1shanks"
assert repr(probe) == expected_probe_representation
expected_probe_representation = "Probe - 384ch"
assert expected_probe_representation in repr(probe)

def test_property_keys(self):
expected_property_keys = [