Skip to content
275 changes: 275 additions & 0 deletions bluecellulab/cell/point_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
from __future__ import annotations

from dataclasses import dataclass
import logging
from pathlib import Path
from typing import Any, Mapping, Optional

from bluecellulab.cell import Cell
from bluecellulab.circuit.simulation_access import get_synapse_replay_spikes
from bluecellulab.exceptions import BluecellulabError
from bluecellulab.circuit import SynapseProperty
from neuron import h
import numpy as np

from bluecellulab.circuit.node_id import CellId

logger = logging.getLogger(__name__)


class BasePointProcessCell(Cell):
"""Base class for NEURON artificial point processes (IntFire1/2/...)."""

def __init__(self, cell_id: Optional[CellId]) -> None:
if cell_id is None:
raise ValueError("PointProcessCell requires valid cell_id")
self.cell_id = cell_id

self._spike_times = h.Vector()
self._spike_detector: Optional[h.NetCon] = None
self.pointcell = None # type: ignore[assignment]
self.synapses: dict = {}
self.connections: dict = {}

@property
def hoc_cell(self):
return self.pointcell

def init_callbacks(self):
pass

def connect_to_circuit(self, proxy) -> None:
self._circuit_proxy = proxy

def delete(self) -> None:
# Stop recording
if self._spike_detector is not None:
# NetCon will be GC'd when no Python refs remain
self._spike_detector = None
if self._spike_times is not None:
self._spike_times = None

# Drop pointer to underlying NEURON object
self.pointcell = None

def get_spike_times(self) -> list[float]:
return list(self._spike_times)

def create_netcon_spikedetector(
self,
sec, # ignored for artificial cells
location=None, # ignored for artificial cells
threshold: float = 0.0,
) -> h.NetCon:
if self.pointcell is None:
raise ValueError("attempting to create netcon without valid pointprocess")
nc = h.NetCon(self.pointcell.pointcell, None)
nc.threshold = threshold # harmless for artificial cells
return nc

def is_recording_spikes(self, location=None, threshold: float | None = None) -> bool:
return self._spike_detector is not None

def start_recording_spikes(self, sec, location=None, threshold: float = 0.0) -> None:
if self._spike_detector is not None:
return
if self.pointcell is None:
raise ValueError("attempting to record spikes without valid pointprocess")
self._spike_times = h.Vector()
self._spike_detector = h.NetCon(self.pointcell.pointcell, None)
self._spike_detector.threshold = threshold
self._spike_detector.record(self._spike_times)

def connect2target(self, target_pp=None) -> h.NetCon:
"""Neurodamus-like helper: NetCon from this cell to a target point process."""
if self.pointcell is None:
raise ValueError("call to connect2target without valid pointprocess")
return h.NetCon(self.pointcell.pointcell, target_pp)


class HocPointProcessCell(BasePointProcessCell):
"""Point process that wraps an arbitrary HOC/mod artificial mechanism."""

def __init__(
self,
cell_id: Optional[CellId],
mechanism_name: str,
param_overrides: Optional[Mapping[str, Any]] = None,
spike_threshold: float = 1.0,
) -> None:
super().__init__(cell_id)

try:
mech_cls = getattr(h, mechanism_name)
except AttributeError as exc:
raise BluecellulabError(
f"Point mechanism '{mechanism_name}' not found in NEURON. "
"Make sure the mod/hoc files are compiled and loaded."
) from exc

if cell_id is None:
raise ValueError("call to create pointprocess mechanism without valid cell_id")
point = mech_cls(cell_id.id)
if param_overrides:
for name, value in param_overrides.items():
if hasattr(point, name):
setattr(point, name, value)

self.pointcell = point
self.start_recording_spikes(None, None, threshold=spike_threshold)

def add_synapse_replay(self, stimulus, spike_threshold: float, spike_location: str) -> None:
"""SONATA-style spike replay for point processes.

This is a simplified analogue of Cell.add_synapse_replay, but
instead of mapping spikes to individual synapses, we directly
connect each presynaptic node_id's spike train to this
artificial cell via VecStim → NetCon.
"""
file_path = Path(stimulus.spike_file).expanduser()

if not file_path.is_absolute():
config_dir = stimulus.config_dir
if config_dir is not None:
file_path = Path(config_dir) / file_path

file_path = file_path.resolve()

if not file_path.exists():
raise FileNotFoundError(f"Spike file not found: {str(file_path)}")

synapse_spikes = get_synapse_replay_spikes(str(file_path))

if not hasattr(self, "_replay_vecs"):
self._replay_vecs: list[h.Vector] = []
if not hasattr(self, "_replay_vecstims"):
self._replay_vecstims: list[h.VecStim] = []
if not hasattr(self, "_replay_netcons"):
self._replay_netcons: list[h.NetCon] = []

for pre_node_id, spikes in synapse_spikes.items():
delay = getattr(stimulus, "delay", 0.0) or 0.0
duration = getattr(stimulus, "duration", np.inf)

spikes_of_interest = spikes[
(spikes >= delay) & (spikes <= duration)
]
if spikes_of_interest.size == 0:
continue

vec = h.Vector(spikes_of_interest)
vs = h.VecStim()
vs.play(vec)

if self.pointcell is None:
raise ValueError("attempting to add replay spikes with valid pointprocess")
nc = h.NetCon(vs, self.pointcell.pointcell)
# Use stimulus weight if available, otherwise default to 1.0
weight = getattr(stimulus, "weight", 1.0)
nc.weight[0] = weight
nc.delay = 0.0 # delay already baked into spike times

self._replay_vecs.append(vec)
self._replay_vecstims.append(vs)
self._replay_netcons.append(nc)

logger.debug(
f"Added replay connection from pre_node_id={pre_node_id} "
f"to point neuron {self.cell_id}"
)

def add_replay_synapse(self, syn_id, syn_description, syn_connection_parameters, condition_parameters,
popids, extracellular_calcium):
"""For Point Neurons, the replay simply queues events directly to the
point obj."""
from bluecellulab.point.point_connection import PointProcessConnection
from bluecellulab.point.connection_params import PointProcessConnParameters

# syn_connection_parameters should only have 1 element, PointProcessConnection will confirm
point_params = PointProcessConnParameters(syn_description[SynapseProperty.PRE_GID], syn_description[SynapseProperty.PRE_GID],
syn_description[SynapseProperty.AXONAL_DELAY])

self.pointConn = PointProcessConnection([point_params])
self.pointConn.finalize(self.pointcell)


def mechanism_name_from_model_template(template_path: str, model_template: str) -> str:
"""Translate SONATA model_template into a NEURON mechanism name.

Examples:
'hoc:AllenPointCell' -> 'AllenPointCell'
'nrn:IntFire1' -> 'IntFire1'
'AllenPointCell' -> 'AllenPointCell'
"""
mt = str(model_template).strip()
if ":" in mt:
prefix, name = mt.split(":", 1)
prefix = prefix.lower()
if prefix in ("hoc", "nrn"):
h.load_file(template_path)
return name
return mt


@dataclass
class IntFire1Params:
tau: float = 10.0
refrac: float = 2.0


class IntFire1Cell(BasePointProcessCell):
def __init__(
self,
cell_id: Optional[CellId] = None,
tau: float = 10.0,
refrac: float = 2.0,
) -> None:
super().__init__(cell_id)
point = h.IntFire1()
point.tau = tau
point.refrac = refrac
self.pointcell = point

self.start_recording_spikes(None, None, threshold=1.0)


@dataclass
class IntFire2Params:
taum: float = 10.0
taus: float = 20.0
ib: float = 0.0


class IntFire2Cell(BasePointProcessCell):
def __init__(
self,
cell_id: Optional[CellId] = None,
taum: float = 10.0,
taus: float = 20.0,
ib: float = 0.0,
) -> None:
super().__init__(cell_id)
point = h.IntFire2()
point.taum = taum
point.taus = taus
point.ib = ib
self.pointcell = point

self.start_recording_spikes(None, None, threshold=1.0)


def create_intfire1_cell(
tau: float = 10.0,
refrac: float = 2.0,
cell_id: Optional[CellId] = None,
) -> IntFire1Cell:
return IntFire1Cell(cell_id=cell_id, tau=tau, refrac=refrac)


def create_intfire2_cell(
taum: float = 10.0,
taus: float = 20.0,
ib: float = 0.0,
cell_id: Optional[CellId] = None,
) -> IntFire2Cell:
return IntFire2Cell(cell_id=cell_id, taum=taum, taus=taus, ib=ib)
25 changes: 23 additions & 2 deletions bluecellulab/circuit/circuit_access/sonata_circuit_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,11 @@ def _select_edge_pop_names(self, projections) -> list[str]:
def extract_synapses(
self, cell_id: CellId, projections: Optional[list[str] | str]
) -> pd.DataFrame:
"""Extract the synapses."""
"""Extract the synapses. Checks available fields to determine which are
present in the edge file to determine the properties to extract.

If projections is None, all the synapses are extracted.
"""
snap_node_id = CircuitNodeId(cell_id.population_name, cell_id.id)
edges = self._circuit.edges

Expand All @@ -200,7 +204,10 @@ def extract_synapses(

# remove optional properties if they are not present
for optional_property in [SynapseProperty.U_HILL_COEFFICIENT,
SynapseProperty.CONDUCTANCE_RATIO]:
SynapseProperty.CONDUCTANCE_RATIO,
SynapseProperty.AFFERENT_SECTION_POS,
SynapseProperty.POST_SEGMENT_ID,
SynapseProperty.POST_SEGMENT_OFFSET]:
if optional_property.to_snap() not in edge_population.property_names:
edge_properties.remove(optional_property)

Expand All @@ -211,6 +218,20 @@ def extract_synapses(
):
edge_properties += list(SynapseProperties.plasticity)

# check for allen instance - replace the entire edge_properties list as appropriate
# properties for allen point/chemical neuron connection type edges
if SynapseProperty.TYPE not in edge_population.property_names:
if all(
x in edge_population.property_names
for x in SynapseProperties.allen_point
):
edge_properties = list(SynapseProperties.allen_point)
if all(
x in edge_population.property_names
for x in SynapseProperties.allen_chemical
):
edge_properties = list(SynapseProperties.allen_chemical)

snap_properties = properties_to_snap(edge_properties)
synapses: pd.DataFrame = edge_population.get(afferent_edges, snap_properties)
column_names = list(synapses.columns)
Expand Down
9 changes: 9 additions & 0 deletions bluecellulab/circuit/synapse_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class SynapseProperty(Enum):
PRE_GID = "pre_gid"
AXONAL_DELAY = "axonal_delay"
POST_SECTION_ID = "post_section_id"
POST_SECTION_POS = "post_section_pos"
POST_SEGMENT_ID = "post_segment_id"
POST_SEGMENT_OFFSET = "post_segment_offset"
G_SYNX = "g_synx"
Expand Down Expand Up @@ -83,6 +84,14 @@ class SynapseProperties:
"volume_CR", "rho0_GB", "Use_d_TM", "Use_p_TM", "gmax_d_AMPA",
"gmax_p_AMPA", "theta_d", "theta_p"
)
allen_chemical = (
"afferent_section_id", "afferent_section_pos", "conductance", "delay", "tau1", "tau2", "erev",
"@source_node"
)
allen_point = (
"afferent_section_id", "afferent_section_pos", "conductance", "delay",
"@source_node"
)


snap_to_synproperty = MappingProxyType({
Expand Down
Loading
Loading