diff --git a/sanafe/__init__.py b/sanafe/__init__.py index 1df941d..1be4680 100644 --- a/sanafe/__init__.py +++ b/sanafe/__init__.py @@ -1,2 +1,6 @@ # Import pybind11 (C++) kernel under top-level from sanafecpp import * + +# Import Python submodules for convenient access +from sanafe import data +from sanafe import viz diff --git a/sanafe/data/__init__.py b/sanafe/data/__init__.py new file mode 100644 index 0000000..25b52e7 --- /dev/null +++ b/sanafe/data/__init__.py @@ -0,0 +1,8 @@ +""" +Provides utilities for loading, converting, and analyzing SANA-FE trace data. +Supports conversion to pandas DataFrames and numpy arrays for analysis. +""" + +from sanafe.data.traces import TraceData + +__all__ = ["TraceData"] diff --git a/sanafe/data/arch_testing.yaml b/sanafe/data/arch_testing.yaml new file mode 100644 index 0000000..720f55b --- /dev/null +++ b/sanafe/data/arch_testing.yaml @@ -0,0 +1,79 @@ +## arch_testing.yaml +architecture: + name: tutorial_complex + attributes: + link_buffer_size: 4 + width: 3 + height: 2 + tile: + - name: tutorial_tile[0..5] + attributes: + energy_north_hop: 1.0e-12 + latency_north_hop: 1.0e-9 + energy_east_hop: 1.0e-12 + latency_east_hop: 1.0e-9 + energy_south_hop: 1.0e-12 + latency_south_hop: 1.0e-9 + energy_west_hop: 1.0e-12 + latency_west_hop: 1.0e-9 + core: + - name: tutorial_core[0..7] + attributes: + buffer_position: soma + max_neurons_supported: 128 + axon_in: + - name: tutorial_axon_in + attributes: + energy_message_in: 0.5e-12 + latency_message_in: 0.5e-9 + synapse: + - name: tutorial_synapse_uncompressed + attributes: + model: current_based + energy_process_spike: 1.0e-12 + latency_process_spike: 1.0e-9 + - name: tutorial_synapse_compressed + attributes: + model: current_based + energy_process_spike: 0.5e-12 + latency_process_spike: 2.0e-9 + - name: tutorial_synapse_sparse + attributes: + model: current_based + energy_process_spike: 0.3e-12 + latency_process_spike: 3.0e-9 + dendrite: + - name: demo_dendrite_accumulator + attributes: + model: accumulator + energy_update: 0.2e-12 + latency_update: 0.2e-9 + - name: demo_dendrite_adaptive + attributes: + model: accumulator + energy_update: 0.4e-12 + latency_update: 0.4e-9 + soma: + - name: tutorial_soma_fast + attributes: + model: leaky_integrate_fire + energy_access_neuron: 1.0e-12 + latency_access_neuron: 1.0e-9 + energy_update_neuron: 2.0e-12 + latency_update_neuron: 2.0e-9 + energy_spike_out: 3.0e-12 + latency_spike_out: 3.0e-9 + - name: tutorial_soma_efficient + attributes: + model: leaky_integrate_fire + energy_access_neuron: 0.8e-12 + latency_access_neuron: 1.2e-9 + energy_update_neuron: 1.6e-12 + latency_update_neuron: 2.4e-9 + energy_spike_out: 2.4e-12 + latency_spike_out: 3.2e-9 + axon_out: + - name: tutorial_axon_out + attributes: + energy_message_out: 4.0e-12 + latency_message_out: 4.0e-9 diff --git a/sanafe/data/traces.py b/sanafe/data/traces.py new file mode 100644 index 0000000..783245c --- /dev/null +++ b/sanafe/data/traces.py @@ -0,0 +1,495 @@ +""" +Provides the TraceData class for unified access to all trace types +produced by SANA-FE simulations, with conversion methods for pandas DataFrames +and numpy arrays. +""" + +from __future__ import annotations + +import csv +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import pandas as pd + + +@dataclass +class SpikeEvent: + """ + Represents a single spike event. + + timestep: The simulation timestep when the spike occurred + group_name: Name of the neuron group containing the spiking neuron + neuron_offset: Index of the neuron within its group + """ + timestep: int + group_name: str + neuron_offset: int + + @property + def neuron_id(self) -> str: + """Full neuron identifier in 'group.offset' format.""" + return f"{self.group_name}.{self.neuron_offset}" + + +@dataclass +class TraceData: + """ + Provides unified access to spike traces, potential traces, performance + metrics, and message traces from SANA-FE simulations. Supports loading + from both in-memory simulation results and CSV files. + + spike_trace: Raw spike trace data (list of lists of NeuronAddress) + potential_trace: Raw potential trace data (list of lists of floats) + perf_trace: Raw performance trace data (dict of metric lists) + message_trace: Raw message trace data (list of lists of message dicts) + neuron_labels: Optional mapping of (group, offset) to custom labels + + From simulation results + From CSV files + Convert to pandas + """ + + spike_trace: Optional[List[List[Any]]] = None + potential_trace: Optional[List[List[float]]] = None + perf_trace: Optional[Dict[str, List[Any]]] = None + message_trace: Optional[List[List[Dict[str, Any]]]] = None + neuron_labels: Dict[Tuple[str, int], str] = field(default_factory=dict) + + # Cache for converted data + _spike_df_cache: Optional[pd.DataFrame] = field(default=None, repr=False) + _potential_df_cache: Optional[pd.DataFrame] = field(default=None, repr=False) + _perf_df_cache: Optional[pd.DataFrame] = field(default=None, repr=False) + _message_df_cache: Optional[pd.DataFrame] = field(default=None, repr=False) + + @classmethod + def from_sim_results(cls, results: Dict[str, Any]) -> "TraceData": + """ + Create TraceData from chip.sim() return dictionary. + + Args: Dictionary returned by SpikingChip.sim(), containing keys like 'spike_trace', 'potential_trace', 'perf_trace', 'message_trace'. + + Returns: TraceData instance with loaded traces. + + results = chip.sim(100, spike_trace=True, perf_trace=True) + traces = TraceData.from_sim_results(results) + """ + return cls( + spike_trace=results.get("spike_trace"), + potential_trace=results.get("potential_trace"), + perf_trace=results.get("perf_trace"), + message_trace=results.get("message_trace"), + ) + + @classmethod + def from_files( + cls, + spike_csv: Optional[Union[str, Path]] = None, + potential_csv: Optional[Union[str, Path]] = None, + perf_csv: Optional[Union[str, Path]] = None, + message_csv: Optional[Union[str, Path]] = None, + ) -> "TraceData": + """ + Create TraceData from CSV trace files. + + spike_csv: Path to spike trace CSV file + potential_csv: Path to potential trace CSV file + perf_csv: Path to performance trace CSV file + message_csv: Path to message trace CSV file + + Returns: TraceData instance with loaded traces. + """ + instance = cls() + + if spike_csv is not None: + instance._load_spike_csv(Path(spike_csv)) + if potential_csv is not None: + instance._load_potential_csv(Path(potential_csv)) + if perf_csv is not None: + instance._load_perf_csv(Path(perf_csv)) + if message_csv is not None: + instance._load_message_csv(Path(message_csv)) + + return instance + + def _load_spike_csv(self, path: Path) -> None: + """Load spike trace from CSV file.""" + # SANA-FE spike CSV format: neuron,timestep + self.spike_trace = [] + max_timestep = 0 + + with open(path, "r") as f: + reader = csv.DictReader(f) + spikes_by_timestep: Dict[int, List[Any]] = {} + + for row in reader: + timestep = int(row["timestep"]) + neuron_str = row["neuron"] + group_name, offset_str = neuron_str.rsplit(".", 1) + + max_timestep = max(max_timestep, timestep) + + if timestep not in spikes_by_timestep: + spikes_by_timestep[timestep] = [] + + # Object mimicking NeuronAddress + spike_info = type("NeuronAddress", (), { + "group_name": group_name, + "neuron_offset": int(offset_str) + })() + spikes_by_timestep[timestep].append(spike_info) + + # Convert to list format (indexed by timestep) + self.spike_trace = [ + spikes_by_timestep.get(t, []) + for t in range(max_timestep + 1) + ] + + def _load_potential_csv(self, path: Path) -> None: + # Potential CSV format: columns are neuron IDs, rows are timesteps + df = pd.read_csv(path) + self.potential_trace = df.values.tolist() + + def _load_perf_csv(self, path: Path) -> None: + df = pd.read_csv(path) + self.perf_trace = df.to_dict(orient="list") + + def _load_message_csv(self, path: Path) -> None: + self.message_trace = [] + + with open(path, "r") as f: + reader = csv.DictReader(f) + messages_by_timestep: Dict[int, List[Dict]] = {} + max_timestep = 0 + + for row in reader: + timestep = int(row["timestep"]) + max_timestep = max(max_timestep, timestep) + + if timestep not in messages_by_timestep: + messages_by_timestep[timestep] = [] + + # Convert numeric fields + msg = {} + for key, value in row.items(): + try: + if "." in value or "e" in value.lower(): + msg[key] = float(value) + else: + msg[key] = int(value) + except (ValueError, AttributeError): + msg[key] = value + + messages_by_timestep[timestep].append(msg) + + self.message_trace = [ + messages_by_timestep.get(t, []) + for t in range(max_timestep + 1) + ] + + def has_spikes(self) -> bool: + return self.spike_trace is not None and len(self.spike_trace) > 0 + + def has_potentials(self) -> bool: + return self.potential_trace is not None and len(self.potential_trace) > 0 + + def has_performance(self) -> bool: + return self.perf_trace is not None and len(self.perf_trace) > 0 + + def has_messages(self) -> bool: + return self.message_trace is not None and len(self.message_trace) > 0 + + @property + def timesteps(self) -> int: + """ + Get the number of timesteps in the trace data. + + Returns the length based on whichever trace is available. + """ + if self.has_spikes(): + return len(self.spike_trace) + if self.has_potentials(): + return len(self.potential_trace) + if self.has_performance() and "timestep" in self.perf_trace: + return len(self.perf_trace["timestep"]) + if self.has_messages(): + return len(self.message_trace) + return 0 + + def spikes_to_dataframe(self) -> pd.DataFrame: + """ + Convert spike trace to pandas DataFrame. + + Returns: + DataFrame with columns: + - timestep: Simulation timestep + - group: Neuron group name + - neuron_offset: Index within group + - neuron_id: Full identifier (group.offset) + """ + if not self.has_spikes(): + raise ValueError("No spike trace data available") + + if self._spike_df_cache is not None: + return self._spike_df_cache.copy() + + records = [] + for timestep, spikes_at_t in enumerate(self.spike_trace): + for spike in spikes_at_t: + records.append({ + "timestep": timestep, + "group": spike.group_name, + "neuron_offset": spike.neuron_offset, + "neuron_id": f"{spike.group_name}.{spike.neuron_offset}", + }) + + self._spike_df_cache = pd.DataFrame(records) + return self._spike_df_cache.copy() + + def spikes_to_events(self) -> List[SpikeEvent]: + """ + Convert spike trace to list of SpikeEvent objects. + + Returns: List of SpikeEvent objects, sorted by timestep. + """ + if not self.has_spikes(): + raise ValueError("No spike trace data available") + + events = [] + for timestep, spikes_at_t in enumerate(self.spike_trace): + for spike in spikes_at_t: + events.append(SpikeEvent( + timestep=timestep, + group_name=spike.group_name, + neuron_offset=spike.neuron_offset, + )) + return events + + def spikes_to_matrix( + self, + neuron_ids: Optional[Sequence[str]] = None, + ) -> Tuple[np.ndarray, List[str]]: + """ + Convert spike trace to binary spike matrix. + + neuron_ids: Optional list of neuron IDs to include. If None, all neurons that spiked at least once are included. + + Returns: Tuple of 2D numpy array of shape (timesteps, neurons) with 1s at spikes and list of neuron ID strings corresponding to columns + """ + if not self.has_spikes(): + raise ValueError("No spike trace data available") + + # Collect all unique neuron IDs if not specified + if neuron_ids is None: + neuron_id_set = set() + for spikes_at_t in self.spike_trace: + for spike in spikes_at_t: + neuron_id_set.add(f"{spike.group_name}.{spike.neuron_offset}") + neuron_ids = sorted(neuron_id_set) + + neuron_ids = list(neuron_ids) + neuron_to_idx = {nid: idx for idx, nid in enumerate(neuron_ids)} + + n_timesteps = len(self.spike_trace) + n_neurons = len(neuron_ids) + + matrix = np.zeros((n_timesteps, n_neurons), dtype=np.int8) + + for timestep, spikes_at_t in enumerate(self.spike_trace): + for spike in spikes_at_t: + neuron_id = f"{spike.group_name}.{spike.neuron_offset}" + if neuron_id in neuron_to_idx: + matrix[timestep, neuron_to_idx[neuron_id]] = 1 + + return matrix, neuron_ids + + def potentials_to_dataframe( + self, + neuron_ids: Optional[Sequence[str]] = None, + ) -> pd.DataFrame: + """ + Convert potential trace to pandas DataFrame. + + neuron_ids: Optional list of neuron ID strings to use as column names. If None, columns are named numerically (0, 1, 2, ...). + + Returns: DataFrame with timestep as index and neurons as columns. + """ + if not self.has_potentials(): + raise ValueError("No potential trace data available") + + df = pd.DataFrame(self.potential_trace) + df.index.name = "timestep" + + if neuron_ids is not None: + if len(neuron_ids) != len(df.columns): + raise ValueError( + f"Number of neuron_ids ({len(neuron_ids)}) doesn't match " + f"number of traced neurons ({len(df.columns)})" + ) + df.columns = neuron_ids + + return df + + def potentials_to_array(self) -> np.ndarray: + """ + Convert potential trace to numpy array. + + Returns: 2D numpy array of shape (timesteps, neurons). + """ + if not self.has_potentials(): + raise ValueError("No potential trace data available") + + return np.array(self.potential_trace) + + def performance_to_dataframe(self) -> pd.DataFrame: + """ + Convert performance trace to pandas DataFrame. + + Returns: + DataFrame with columns for each performance metric: + - timestep, fired, updated, packets, hops, spikes + - sim_time, synapse_energy, dendrite_energy, soma_energy + - network_energy, total_energy + """ + if not self.has_performance(): + raise ValueError("No performance trace data available") + + if self._perf_df_cache is not None: + return self._perf_df_cache.copy() + + self._perf_df_cache = pd.DataFrame(self.perf_trace) + return self._perf_df_cache.copy() + + def messages_to_dataframe(self) -> pd.DataFrame: + """ + Convert message trace to pandas DataFrame. + + Returns: + DataFrame with columns for each message attribute: + - timestep, mid, src_neuron, src_hw, dest_hw + - hops, spikes, send_timestamp, received_timestamp + - processed_timestamp, generation_delay, processing_delay + - network_delay, blocking_delay, messages_along_route + """ + if not self.has_messages(): + raise ValueError("No message trace data available") + + if self._message_df_cache is not None: + return self._message_df_cache.copy() + + records = [] + for timestep, messages_at_t in enumerate(self.message_trace): + for msg in messages_at_t: + record = {"timestep": timestep} + record.update(msg) + records.append(record) + + self._message_df_cache = pd.DataFrame(records) + return self._message_df_cache.copy() + + def get_neuron_groups(self) -> List[str]: + """ + Get list of unique neuron group names from spike trace. + + Returns: Sorted list of group names. + """ + if not self.has_spikes(): + return [] + + groups = set() + for spikes_at_t in self.spike_trace: + for spike in spikes_at_t: + groups.add(spike.group_name) + + return sorted(groups) + + def filter_by_groups( + self, + groups: Sequence[str], + ) -> "TraceData": + """ + Create a new TraceData with only spikes from specified groups. + + groups: List of group names to include. + + Returns: New TraceData instance with filtered spike data and other trace types are copied as-is. + """ + groups_set = set(groups) + + filtered_spikes = None + if self.has_spikes(): + filtered_spikes = [] + for spikes_at_t in self.spike_trace: + filtered_at_t = [ + s for s in spikes_at_t + if s.group_name in groups_set + ] + filtered_spikes.append(filtered_at_t) + + return TraceData( + spike_trace=filtered_spikes, + potential_trace=self.potential_trace, + perf_trace=self.perf_trace, + message_trace=self.message_trace, + neuron_labels=self.neuron_labels.copy(), + ) + + def filter_by_time( + self, + start: Optional[int] = None, + end: Optional[int] = None, + ) -> "TraceData": + """ + Create a new TraceData filtered to a time range. + + start: Start timestep (inclusive). None means from beginning. + end: End timestep (exclusive). None means to end. + + Returns: New TraceData instance with filtered time range. + """ + start = start or 0 + end = end or self.timesteps + + filtered_spikes = None + if self.has_spikes(): + filtered_spikes = self.spike_trace[start:end] + + filtered_potentials = None + if self.has_potentials(): + filtered_potentials = self.potential_trace[start:end] + + filtered_messages = None + if self.has_messages(): + filtered_messages = self.message_trace[start:end] + + # Performance trace needs special handling (dict of lists) + filtered_perf = None + if self.has_performance(): + filtered_perf = {} + for key, values in self.perf_trace.items(): + filtered_perf[key] = values[start:end] + + return TraceData( + spike_trace=filtered_spikes, + potential_trace=filtered_potentials, + perf_trace=filtered_perf, + message_trace=filtered_messages, + neuron_labels=self.neuron_labels.copy(), + ) + + def __repr__(self) -> str: + parts = [f"TraceData(timesteps={self.timesteps}"] + if self.has_spikes(): + total_spikes = sum(len(s) for s in self.spike_trace) + parts.append(f"spikes={total_spikes}") + if self.has_potentials(): + n_neurons = len(self.potential_trace[0]) if self.potential_trace else 0 + parts.append(f"potential_neurons={n_neurons}") + if self.has_performance(): + parts.append("perf=True") + if self.has_messages(): + total_msgs = sum(len(m) for m in self.message_trace) + parts.append(f"messages={total_msgs}") + return ", ".join(parts) + ")" diff --git a/sanafe/viz/__init__.py b/sanafe/viz/__init__.py new file mode 100644 index 0000000..9ec109d --- /dev/null +++ b/sanafe/viz/__init__.py @@ -0,0 +1,41 @@ +""" +Provides plotting utilities for visualizing SANA-FE simulation outputs, +including spike raster plots, potential timeseries, and performance metrics. +""" + +from sanafe.viz.raster import raster_plot +from sanafe.viz.potential import potential_plot +from sanafe.viz.performance import ( + energy_breakdown_plot, + throughput_plot, + latency_histogram, + latency_comparison, +) +from sanafe.viz.styles import ( + SANAFEStyle, + get_group_colors, + apply_style, + set_default_style, + PUBLICATION_STYLE, + PRESENTATION_STYLE, + NOTEBOOK_STYLE, +) + +__all__ = [ + # SNN visualization + "raster_plot", + "potential_plot", + # Performance visualization + "energy_breakdown_plot", + "throughput_plot", + "latency_histogram", + "latency_comparison", + # Styling + "SANAFEStyle", + "get_group_colors", + "apply_style", + "set_default_style", + "PUBLICATION_STYLE", + "PRESENTATION_STYLE", + "NOTEBOOK_STYLE", +] diff --git a/sanafe/viz/performance.py b/sanafe/viz/performance.py new file mode 100644 index 0000000..cba89f9 --- /dev/null +++ b/sanafe/viz/performance.py @@ -0,0 +1,733 @@ +""" +This module provides functions for visualizing hardware performance metrics +and message-level latency data from SANA-FE simulations. Includes energy +breakdown plots, throughput timeseries, and latency distribution histograms. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import matplotlib.ticker as ticker +import numpy as np +import pandas as pd + +from sanafe.data.traces import TraceData +from sanafe.viz.styles import ( + SANAFEStyle, + DEFAULT_COLORS, + create_figure, + get_default_style, + style_axis, +) + +_ENERGY_COLUMNS = [ + "synapse_energy", + "dendrite_energy", + "soma_energy", + "network_energy", +] + +_LATENCY_COLUMNS = [ + "generation_delay", + "receive_delay", + "network_delay", + "blocked_delay", +] + + +def _coerce_to_trace_data(data: Union[TraceData, Dict[str, Any]]) -> TraceData: + """Accept raw results dict or TraceData, always return TraceData.""" + if isinstance(data, dict): + return TraceData.from_sim_results(data) + return data + + +def _get_perf_df(data: TraceData) -> pd.DataFrame: + """Extract perf DataFrame or raise a clear error.""" + if not data.has_performance(): + raise ValueError( + "No performance trace data available. " + "Pass perf_trace to chip.sim() or load a perf CSV." + ) + return data.performance_to_dataframe() + + +def _get_message_df(data: TraceData) -> pd.DataFrame: + """Extract message DataFrame or raise a clear error.""" + if not data.has_messages(): + raise ValueError( + "No message trace data available. " + "Pass message_trace to chip.sim() or load a message CSV." + ) + return data.messages_to_dataframe() + + +def _auto_energy_scale( + values: np.ndarray, +) -> Tuple[np.ndarray, str]: + """ + Pick a human-readable energy unit and return scaled values + unit label. + + Chooses the unit so that peak values fall in roughly 0.1 - 999 range. + """ + peak = np.nanmax(values) if len(values) > 0 else 0.0 + if peak == 0.0: + return values, "J" + + scales = [ + (1e-15, "fJ"), + (1e-12, "pJ"), + (1e-9, "nJ"), + (1e-6, "µJ"), + (1e-3, "mJ"), + (1.0, "J"), + ] + for factor, label in scales: + if peak / factor < 1000: + return values / factor, label + + return values, "J" + + +def _auto_time_scale( + values: np.ndarray, +) -> Tuple[np.ndarray, str]: + """Pick a human-readable time unit for delay values.""" + peak = np.nanmax(np.abs(values)) if len(values) > 0 else 0.0 + if peak == 0.0: + return values, "s" + + scales = [ + (1e-12, "ps"), + (1e-9, "ns"), + (1e-6, "µs"), + (1e-3, "ms"), + (1.0, "s"), + ] + for factor, label in scales: + if peak / factor < 1000: + return values / factor, label + + return values, "s" + + +def energy_breakdown_plot( + data: Union[TraceData, Dict[str, Any]], + time_range: Optional[Tuple[int, int]] = None, + mode: str = "stacked_area", + components: Optional[Sequence[str]] = None, + normalize: bool = False, + show_total: bool = False, + colors: Optional[Sequence[str]] = None, + component_labels: Optional[Sequence[str]] = None, + show_legend: bool = True, + ax: Optional[plt.Axes] = None, + style: Optional[SANAFEStyle] = None, + figsize: Optional[Tuple[float, float]] = None, + title: Optional[str] = None, + xlabel: str = "Time-step", + ylabel: Optional[str] = None, + **kwargs, +) -> Tuple[plt.Figure, plt.Axes]: + """ + Plot energy consumption breakdown over time from the performance trace. + + Visualizes per-component energy (synapse, dendrite, soma, network) at each + timestep. Supports stacked area, stacked bar, and grouped bar modes. + + Args: + data: TraceData or raw results dict from chip.sim(). + time_range: (start, end) timestep range to display. + mode: Plot mode - one of: + - "stacked_area": filled area chart with components stacked + - "stacked_bar": stacked bar chart per timestep + - "bar": grouped (side-by-side) bar chart + components: Which energy columns to include. Defaults to + ["synapse_energy", "dendrite_energy", "soma_energy", + "network_energy"]. Can also be ["total_energy"] for a single line. + normalize: If True, show each timestep as a percentage breakdown + (100% stacked). Only applies to stacked modes. + show_total: If True, overlay the total_energy line on top of + the stacked plot. + colors: List of colors for each component. If None, uses + style.energy_colors. + component_labels: Display names for each component. + show_legend: Whether to show a legend. + ax: Existing Axes to plot on. + style: Style configuration. + figsize: Figure size (width, height). + title: Plot title. + xlabel: X-axis label. + ylabel: Y-axis label. Auto-generated with unit if None. + **kwargs: Extra arguments forwarded to the underlying matplotlib call + (fill_between for area, bar for bar chart). + + Returns: Tuple of (Figure, Axes). + """ + data = _coerce_to_trace_data(data) + df = _get_perf_df(data) + + if style is None: + style = get_default_style() + + if components is None: + components = list(_ENERGY_COLUMNS) + missing = [c for c in components if c not in df.columns] + if missing: + raise ValueError( + f"Columns not found in perf trace: {missing}. " + f"Available: {list(df.columns)}" + ) + + if time_range is not None: + start, end = time_range + if "timestep" in df.columns: + df = df[(df["timestep"] >= start) & (df["timestep"] < end)] + else: + df = df.iloc[start:end] + + if "timestep" in df.columns: + timesteps = df["timestep"].values + else: + timesteps = np.arange(len(df)) + + # Pull raw energy arrays and compute auto scale + raw_arrays = [df[c].values.astype(float) for c in components] + + # Find a common unit from the total of all selected components + combined_peak = np.stack(raw_arrays).sum(axis=0) + _, unit = _auto_energy_scale(combined_peak) + factor = {"fJ": 1e-15, "pJ": 1e-12, "nJ": 1e-9, + "µJ": 1e-6, "mJ": 1e-3, "J": 1.0}.get(unit, 1.0) + scaled = [arr / factor for arr in raw_arrays] + + if normalize: + totals = sum(scaled) + with np.errstate(invalid="ignore", divide="ignore"): + scaled = [ + np.where(totals > 0, arr / totals * 100.0, 0.0) + for arr in scaled + ] + unit = "%" + + # Resolve colors and labels + if colors is None: + colors = style.energy_colors[: len(components)] + if component_labels is None: + component_labels = ( + style.energy_component_names[: len(components)] + if len(components) <= len(style.energy_component_names) + else [c.replace("_energy", "").replace("_", " ").title() + for c in components] + ) + + if ax is None: + fig, ax = create_figure(figsize=figsize, style=style) + else: + fig = ax.get_figure() + + if mode == "stacked_area": + cumulative = np.zeros_like(scaled[0]) + for i, (vals, label, color) in enumerate( + zip(scaled, component_labels, colors) + ): + ax.fill_between( + timesteps, cumulative, cumulative + vals, + label=label, color=color, + alpha=style.perf_fill_alpha + 0.4, + **kwargs, + ) + cumulative = cumulative + vals + + if show_total and not normalize: + total_scaled = df["total_energy"].values.astype(float) / factor + ax.plot( + timesteps, total_scaled, + color="black", linewidth=style.perf_line_width, + linestyle="--", label="Total", zorder=5, + ) + + elif mode == "stacked_bar": + bar_width = 0.8 + bottoms = np.zeros(len(timesteps)) + for vals, label, color in zip(scaled, component_labels, colors): + ax.bar( + timesteps, vals, bottom=bottoms, + width=bar_width, label=label, color=color, + edgecolor=style.hist_edgecolor, + linewidth=style.hist_edgewidth, + **kwargs, + ) + bottoms += vals + + elif mode == "bar": + n_comp = len(components) + bar_width = 0.8 / n_comp + offsets = np.linspace( + -(n_comp - 1) * bar_width / 2, + (n_comp - 1) * bar_width / 2, + n_comp, + ) + for offset, vals, label, color in zip( + offsets, scaled, component_labels, colors + ): + ax.bar( + timesteps + offset, vals, + width=bar_width, label=label, color=color, + edgecolor=style.hist_edgecolor, + linewidth=style.hist_edgewidth, + **kwargs, + ) + else: + raise ValueError( + f"Unknown mode '{mode}'. Use 'stacked_area', 'stacked_bar', or 'bar'." + ) + + if ylabel is None: + ylabel = f"Energy ({unit})" if not normalize else "Energy (%)" + style_axis(ax, style, xlabel=xlabel, ylabel=ylabel, title=title) + + if show_legend: + ax.legend(loc="upper right", framealpha=0.9) + + if style.tight_layout: + fig.tight_layout() + + return fig, ax + + +def throughput_plot( + data: Union[TraceData, Dict[str, Any]], + metrics: Optional[Sequence[str]] = None, + time_range: Optional[Tuple[int, int]] = None, + colors: Optional[Sequence[str]] = None, + labels: Optional[Sequence[str]] = None, + show_legend: bool = True, + secondary_y: Optional[Sequence[str]] = None, + ax: Optional[plt.Axes] = None, + style: Optional[SANAFEStyle] = None, + figsize: Optional[Tuple[float, float]] = None, + title: Optional[str] = None, + xlabel: str = "Time-step", + ylabel: Optional[str] = None, + **plot_kwargs, +) -> Tuple[plt.Figure, plt.Axes]: + """ + Plot throughput / activity metrics over time from the performance trace. + + Renders one or more per-timestep performance counters as line plots. + Useful for visualizing how firing rate, packet count, and spike volume + evolve across the simulation. + + Args: + data: TraceData or raw results dict from chip.sim(). + metrics: List of column names to plot. Defaults to + ["fired", "packets", "spikes"]. + time_range: (start, end) timestep range to display. + colors: List of colors for each metric. + labels: Display names for each metric. If None, derived from column + names (underscores replaced, title-cased). + show_legend: Whether to show a legend. + secondary_y: List of metric names that should use a secondary y-axis + on the right. Useful when combining counts (fired) with energy + (total_energy) on one plot. + ax: Existing Axes to plot on. + style: Style configuration. + figsize: Figure size (width, height). + title: Plot title. + xlabel: X-axis label. + ylabel: Y-axis label. + **plot_kwargs: Extra arguments forwarded to ax.plot(). + + Returns: + Tuple of (Figure, Axes). When secondary_y is used, the secondary + Axes is accessible via ``ax.right_ax`` (stored as an attribute). + """ + data = _coerce_to_trace_data(data) + df = _get_perf_df(data) + + if style is None: + style = get_default_style() + + if metrics is None: + metrics = ["fired", "packets", "spikes"] + missing = [m for m in metrics if m not in df.columns] + if missing: + raise ValueError( + f"Columns not found in perf trace: {missing}. " + f"Available: {list(df.columns)}" + ) + + if secondary_y is None: + secondary_y = [] + secondary_y_set = set(secondary_y) + + if time_range is not None: + start, end = time_range + if "timestep" in df.columns: + df = df[(df["timestep"] >= start) & (df["timestep"] < end)] + else: + df = df.iloc[start:end] + + if "timestep" in df.columns: + timesteps = df["timestep"].values + else: + timesteps = np.arange(len(df)) + + if colors is None: + colors = [DEFAULT_COLORS[i % len(DEFAULT_COLORS)] + for i in range(len(metrics))] + if labels is None: + labels = [m.replace("_", " ").title() for m in metrics] + + if ax is None: + fig, ax = create_figure(figsize=figsize, style=style) + else: + fig = ax.get_figure() + + ax_right = None + if secondary_y_set: + ax_right = ax.twinx() + ax.right_ax = ax_right # stash for user access + + plot_defaults = { + "linewidth": style.perf_line_width, + } + if style.perf_marker: + plot_defaults["marker"] = style.perf_marker + plot_defaults["markersize"] = style.perf_marker_size + plot_defaults.update(plot_kwargs) + + right_lines = [] + left_lines = [] + + for metric, color, label in zip(metrics, colors, labels): + values = df[metric].values.astype(float) + target_ax = ax_right if metric in secondary_y_set else ax + + # Auto-scale energy columns + scaled_label = label + if "energy" in metric: + values, unit = _auto_energy_scale(values) + scaled_label = f"{label} ({unit})" + elif "sim_time" in metric: + values, unit = _auto_time_scale(values) + scaled_label = f"{label} ({unit})" + + line, = target_ax.plot( + timesteps, values, + color=color, label=scaled_label, + **plot_defaults, + ) + if metric in secondary_y_set: + right_lines.append(line) + else: + left_lines.append(line) + + if ylabel is None: + ylabel = "Count" + style_axis(ax, style, xlabel=xlabel, ylabel=ylabel, title=title) + + if ax_right is not None: + ax_right.tick_params(axis="y", labelsize=style.tick_size) + # Build unified legend from both axes + if show_legend: + all_lines = left_lines + right_lines + all_labels = [l.get_label() for l in all_lines] + ax.legend(all_lines, all_labels, loc="upper right", framealpha=0.9) + elif show_legend: + ax.legend(loc="upper right", framealpha=0.9) + + if style.tight_layout: + fig.tight_layout() + + return fig, ax + + + +def latency_histogram( + data: Union[TraceData, Dict[str, Any]], + metric: str = "processing_delay", + filter_placeholder: bool = True, + time_range: Optional[Tuple[int, int]] = None, + bins: Optional[int] = None, + log_scale: bool = False, + color: Optional[str] = None, + show_stats: bool = True, + ax: Optional[plt.Axes] = None, + style: Optional[SANAFEStyle] = None, + figsize: Optional[Tuple[float, float]] = None, + title: Optional[str] = None, + xlabel: Optional[str] = None, + ylabel: str = "Count", + **hist_kwargs, +) -> Tuple[plt.Figure, plt.Axes]: + """ + Plot a histogram of message-level latency or delay values. + + Draws a distribution of a chosen delay metric from the message trace. + Useful for understanding timing characteristics and identifying + congestion-related outliers. + + Args: + data: TraceData or raw results dict from chip.sim(). + metric: Which delay column to histogram. One of: + - "generation_delay" : time from soma fire to message send + - "processing_delay" : time from receive to processed + - "network_delay" : time in the NoC (send to receive) + - "blocking_delay" : extra delay due to congestion + - "hops" : number of NoC hops (integer) + Or any other numeric column in the message trace. + filter_placeholder: If True (default), exclude placeholder messages + (mid == -1) that represent timesteps where no real packet was sent. + time_range: (start, end) timestep range to include. + bins: Number of histogram bins. If None, uses style.hist_bins. + log_scale: If True, use a log scale on the y-axis. + color: Bar color. If None, uses the first default color. + show_stats: If True, annotate the plot with mean, median, and std. + ax: Existing Axes to plot on. + style: Style configuration. + figsize: Figure size (width, height). + title: Plot title. If None, auto-generated from metric name. + xlabel: X-axis label. If None, auto-generated from metric name. + ylabel: Y-axis label. + **hist_kwargs: Extra arguments forwarded to ax.hist(). + + Returns: Tuple of (Figure, Axes). + + """ + data = _coerce_to_trace_data(data) + df = _get_message_df(data) + + if style is None: + style = get_default_style() + + if metric not in df.columns: + raise ValueError( + f"Column '{metric}' not found in message trace. " + f"Available: {list(df.columns)}" + ) + + if filter_placeholder and "mid" in df.columns: + df = df[df["mid"] >= 0] + + if time_range is not None: + start, end = time_range + df = df[(df["timestep"] >= start) & (df["timestep"] < end)] + + values = df[metric].values.astype(float) + + mask = np.isfinite(values) + values = values[mask] + + if len(values) == 0: + # Nothing to plot — still create the figure and return it + if ax is None: + fig, ax = create_figure(figsize=figsize, style=style) + else: + fig = ax.get_figure() + ax.text( + 0.5, 0.5, "No data", + ha="center", va="center", transform=ax.transAxes, + fontsize=style.label_size, color="#999999", + ) + style_axis(ax, style, xlabel=xlabel or metric, ylabel=ylabel, + title=title) + return fig, ax + + # Auto-scale time-based metrics + is_time_metric = "delay" in metric or "latency" in metric or "timestamp" in metric + unit = "" + if is_time_metric: + values, unit = _auto_time_scale(values) + + if bins is None: + bins = style.hist_bins + if color is None: + color = DEFAULT_COLORS[0] + + if ax is None: + fig, ax = create_figure(figsize=figsize, style=style) + else: + fig = ax.get_figure() + + hist_defaults = { + "bins": bins, + "color": color, + "alpha": style.hist_alpha, + "edgecolor": style.hist_edgecolor, + "linewidth": style.hist_edgewidth, + } + hist_defaults.update(hist_kwargs) + ax.hist(values, **hist_defaults) + + if log_scale: + ax.set_yscale("log") + + if show_stats and len(values) > 0: + mean_val = np.mean(values) + median_val = np.median(values) + std_val = np.std(values) + stat_text = ( + f"mean = {mean_val:.3g}\n" + f"median = {median_val:.3g}\n" + f"std = {std_val:.3g}\n" + f"n = {len(values)}" + ) + ax.text( + 0.97, 0.95, stat_text, + transform=ax.transAxes, + ha="right", va="top", + fontsize=style.tick_size, + bbox=dict(boxstyle="round,pad=0.4", facecolor="white", + edgecolor="#cccccc", alpha=0.9), + ) + + if xlabel is None: + pretty = metric.replace("_", " ").title() + xlabel = f"{pretty} ({unit})" if unit else pretty + if title is None: + title = f"Distribution of {metric.replace('_', ' ').title()}" + + style_axis(ax, style, xlabel=xlabel, ylabel=ylabel, title=title) + + if style.tight_layout: + fig.tight_layout() + + return fig, ax + + +def latency_comparison( + data: Union[TraceData, Dict[str, Any]], + metrics: Optional[Sequence[str]] = None, + filter_placeholder: bool = True, + time_range: Optional[Tuple[int, int]] = None, + bins: Optional[int] = None, + colors: Optional[Sequence[str]] = None, + labels: Optional[Sequence[str]] = None, + log_scale: bool = False, + show_legend: bool = True, + ax: Optional[plt.Axes] = None, + style: Optional[SANAFEStyle] = None, + figsize: Optional[Tuple[float, float]] = None, + title: str = "Delay Component Comparison", + xlabel: Optional[str] = None, + ylabel: str = "Count", + **hist_kwargs, +) -> Tuple[plt.Figure, plt.Axes]: + """ + Overlay multiple delay distributions on the same axes. + + Plots overlapping histograms for several delay components so their + distributions can be visually compared. + + Args: + data: TraceData or raw results dict from chip.sim(). + metrics: Delay columns to compare. Defaults to + ["generation_delay", "processing_delay", "network_delay", + "blocking_delay"]. + filter_placeholder: Exclude placeholder messages (mid == -1). + time_range: (start, end) timestep range. + bins: Number of bins for all histograms. + colors: Color per metric. + labels: Display names per metric. + log_scale: Log y-axis. + show_legend: Show legend. + ax: Existing Axes. + style: Style configuration. + figsize: Figure size. + title: Plot title. + xlabel: X-axis label. Auto-scaled with time unit if None. + ylabel: Y-axis label. + **hist_kwargs: Extra arguments forwarded to ax.hist(). + + Returns: Tuple of (Figure, Axes). + + """ + data = _coerce_to_trace_data(data) + df = _get_message_df(data) + + if style is None: + style = get_default_style() + + if metrics is None: + metrics = list(_LATENCY_COLUMNS) + missing = [m for m in metrics if m not in df.columns] + if missing: + raise ValueError( + f"Columns not found in message trace: {missing}. " + f"Available: {list(df.columns)}" + ) + + if filter_placeholder and "mid" in df.columns: + df = df[df["mid"] >= 0] + + if time_range is not None: + start, end = time_range + df = df[(df["timestep"] >= start) & (df["timestep"] < end)] + + # Collect all finite values across metrics to find a common time unit + all_values = np.concatenate([ + df[m].values.astype(float) for m in metrics + ]) + all_values = all_values[np.isfinite(all_values)] + if len(all_values) == 0: + if ax is None: + fig, ax = create_figure(figsize=figsize, style=style) + else: + fig = ax.get_figure() + ax.text(0.5, 0.5, "No data", ha="center", va="center", + transform=ax.transAxes, fontsize=style.label_size, color="#999") + return fig, ax + + _, unit = _auto_time_scale(all_values) + factor = {"ps": 1e-12, "ns": 1e-9, "µs": 1e-6, + "ms": 1e-3, "s": 1.0}.get(unit, 1.0) + + # Resolve colors and labels + if colors is None: + colors = [DEFAULT_COLORS[i % len(DEFAULT_COLORS)] + for i in range(len(metrics))] + if labels is None: + labels = [m.replace("_", " ").title() for m in metrics] + + if bins is None: + bins = style.hist_bins + + # Create figure + if ax is None: + fig, ax = create_figure(figsize=figsize, style=style) + else: + fig = ax.get_figure() + + hist_defaults = { + "bins": bins, + "alpha": style.hist_alpha * 0.85, + "edgecolor": style.hist_edgecolor, + "linewidth": style.hist_edgewidth, + } + hist_defaults.update(hist_kwargs) + + for metric, color, label in zip(metrics, colors, labels): + vals = df[metric].values.astype(float) + vals = vals[np.isfinite(vals)] / factor + if len(vals) > 0: + ax.hist(vals, color=color, label=label, **hist_defaults) + + if log_scale: + ax.set_yscale("log") + + if xlabel is None: + xlabel = f"Delay ({unit})" + + style_axis(ax, style, xlabel=xlabel, ylabel=ylabel, title=title) + + if show_legend: + ax.legend(loc="upper right", framealpha=0.9) + + if style.tight_layout: + fig.tight_layout() + + return fig, ax diff --git a/sanafe/viz/potential.py b/sanafe/viz/potential.py new file mode 100644 index 0000000..5d2ee7e --- /dev/null +++ b/sanafe/viz/potential.py @@ -0,0 +1,370 @@ +""" +This module provides functions for creating potential timeseries plots, +which display membrane voltage traces for individual neurons over time. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np + +from sanafe.data.traces import TraceData +from sanafe.viz.styles import ( + SANAFEStyle, + create_figure, + get_default_style, + style_axis, + DEFAULT_COLORS, +) + + +def potential_plot( + data: Union[TraceData, Dict[str, Any]], + neuron_ids: Optional[Sequence[str]] = None, + time_range: Optional[Tuple[int, int]] = None, + colors: Optional[Sequence[str]] = None, + show_threshold: Optional[float] = None, + threshold_color: str = "#d62728", + threshold_linestyle: str = "--", + show_legend: bool = True, + ax: Optional[plt.Axes] = None, + style: Optional[SANAFEStyle] = None, + figsize: Optional[Tuple[float, float]] = None, + title: Optional[str] = None, + xlabel: str = "Time-step", + ylabel: str = "Membrane Potential", + **plot_kwargs, +) -> Tuple[plt.Figure, plt.Axes]: + """ + Create a membrane potential timeseries plot from SANA-FE trace data. + Displays membrane voltage traces for neurons over time. Multiple neurons + can be shown on the same plot with different colors. + + Args: + data: Either a TraceData object or raw results dict from chip.sim() + neuron_ids: List of neuron ID strings for legend labels. If None, neurons are labeled numerically (Neuron 0, Neuron 1, ...). + time_range: Tuple of (start, end) timesteps to display. If None, shows all. + colors: List of colors for each neuron trace. If None, auto-assigned. + show_threshold: If set, draws a horizontal threshold line at this value. + threshold_color: Color for the threshold line. + threshold_linestyle: Line style for the threshold line. + show_legend: Whether to show a legend. + ax: Existing Axes to plot on. If None, creates new figure. + style: Style configuration. If None, uses default style. + figsize: Figure size (width, height). If None, uses style default. + title: Plot title. If None, no title is shown. + xlabel: X-axis label. + ylabel: Y-axis label. + **plot_kwargs: Additional arguments passed to plt.plot(). + + Returns: Tuple of (Figure, Axes) objects. + """ + # Handle raw results dict + if isinstance(data, dict): + data = TraceData.from_sim_results(data) + + if not data.has_potentials(): + raise ValueError("No potential trace data available in the provided data") + + if style is None: + style = get_default_style() + + # Get potential data as array + potentials = data.potentials_to_array() + n_timesteps, n_neurons = potentials.shape + + # Apply time range filter + time_offset = 0 + if time_range is not None: + start, end = time_range + potentials = potentials[start:end, :] + n_timesteps = potentials.shape[0] + time_offset = start + + # Generate default neuron labels + if neuron_ids is None: + neuron_ids = [f"Neuron {i}" for i in range(n_neurons)] + elif len(neuron_ids) != n_neurons: + raise ValueError( + f"Number of neuron_ids ({len(neuron_ids)}) doesn't match " + f"number of traced neurons ({n_neurons})" + ) + + if colors is None: + colors = DEFAULT_COLORS[:n_neurons] + if len(colors) < n_neurons: + # Extend with cycling if not enough colors + colors = [DEFAULT_COLORS[i % len(DEFAULT_COLORS)] for i in range(n_neurons)] + + # Create figure if needed + if ax is None: + fig, ax = create_figure(figsize=figsize, style=style) + else: + fig = ax.get_figure() + + # Time axis + timesteps = np.arange(time_offset, time_offset + n_timesteps) + + plot_defaults = { + "linewidth": style.potential_line_width, + } + if style.potential_marker: + plot_defaults["marker"] = style.potential_marker + plot_defaults["markersize"] = style.potential_marker_size + plot_defaults.update(plot_kwargs) + + # Plot each neuron's potential trace + lines = [] + for i in range(n_neurons): + line, = ax.plot( + timesteps, + potentials[:, i], + color=colors[i], + label=neuron_ids[i], + **plot_defaults, + ) + lines.append(line) + + # Add threshold line if specified + if show_threshold is not None: + ax.axhline( + y=show_threshold, + color=threshold_color, + linestyle=threshold_linestyle, + linewidth=style.line_width * 0.8, + label="Threshold", + zorder=0, # Behind the traces + ) + + ax.set_xlim(timesteps[0] - 0.5, timesteps[-1] + 0.5) + + style_axis(ax, style, xlabel=xlabel, ylabel=ylabel, title=title) + + # Legend + if show_legend: + ax.legend(loc="upper right", framealpha=0.9) + + if style.tight_layout: + fig.tight_layout() + + return fig, ax + + +def potential_heatmap( + data: Union[TraceData, Dict[str, Any], np.ndarray], + neuron_ids: Optional[Sequence[str]] = None, + time_range: Optional[Tuple[int, int]] = None, + cmap: str = "viridis", + vmin: Optional[float] = None, + vmax: Optional[float] = None, + show_colorbar: bool = True, + ax: Optional[plt.Axes] = None, + style: Optional[SANAFEStyle] = None, + figsize: Optional[Tuple[float, float]] = None, + title: Optional[str] = None, + xlabel: str = "Time-step", + ylabel: str = "Neuron", + **imshow_kwargs, +) -> Tuple[plt.Figure, plt.Axes]: + """ + Heatmap visualization of membrane potentials over time. + + Args: + data: TraceData, results dict, or numpy array of potentials + neuron_ids: List of neuron ID strings for y-axis labels + time_range: Tuple of (start, end) timesteps to display + cmap: Colormap name for the heatmap + vmin: Minimum value for color scaling + vmax: Maximum value for color scaling + show_colorbar: Whether to show a colorbar + ax: Existing Axes to plot on + style: Style configuration + figsize: Figure size (width, height) + title: Plot title + xlabel: X-axis label + ylabel: Y-axis label + **imshow_kwargs: Additional arguments passed to plt.imshow() + + Returns: Tuple of (Figure, Axes) objects. + """ + # Handle different input types + if isinstance(data, dict): + data = TraceData.from_sim_results(data) + + if isinstance(data, TraceData): + if not data.has_potentials(): + raise ValueError("No potential trace data available") + potentials = data.potentials_to_array() + else: + potentials = np.asarray(data) + + if style is None: + style = get_default_style() + + n_timesteps, n_neurons = potentials.shape + + # Apply time range filter + time_offset = 0 + if time_range is not None: + start, end = time_range + potentials = potentials[start:end, :] + n_timesteps = potentials.shape[0] + time_offset = start + + if ax is None: + fig, ax = create_figure(figsize=figsize, style=style) + else: + fig = ax.get_figure() + + # Create heatmap (transpose so neurons are on y-axis) + imshow_defaults = { + "aspect": "auto", + "origin": "lower", + "cmap": cmap, + "extent": [time_offset - 0.5, time_offset + n_timesteps - 0.5, + -0.5, n_neurons - 0.5], + } + if vmin is not None: + imshow_defaults["vmin"] = vmin + if vmax is not None: + imshow_defaults["vmax"] = vmax + imshow_defaults.update(imshow_kwargs) + + im = ax.imshow(potentials.T, **imshow_defaults) + + if show_colorbar: + cbar = fig.colorbar(im, ax=ax) + cbar.set_label("Membrane Potential") + + # Configure y-axis labels + if neuron_ids is not None: + if n_neurons <= 20: + ax.set_yticks(range(n_neurons)) + ax.set_yticklabels(neuron_ids) + else: + # Show subset of labels + step = n_neurons // 10 + tick_positions = list(range(0, n_neurons, step)) + tick_labels = [neuron_ids[i] for i in tick_positions] + ax.set_yticks(tick_positions) + ax.set_yticklabels(tick_labels) + + style_axis(ax, style, xlabel=xlabel, ylabel=ylabel, title=title) + + if style.tight_layout: + fig.tight_layout() + + return fig, ax + + +def potential_subplots( + data: Union[TraceData, Dict[str, Any]], + neuron_ids: Optional[Sequence[str]] = None, + time_range: Optional[Tuple[int, int]] = None, + ncols: int = 1, + colors: Optional[Sequence[str]] = None, + show_threshold: Optional[float] = None, + sharex: bool = True, + sharey: bool = True, + style: Optional[SANAFEStyle] = None, + figsize: Optional[Tuple[float, float]] = None, + **plot_kwargs, +) -> Tuple[plt.Figure, np.ndarray]: + """ + Grid of subplots, one per neuron potential trace. + + Args: + data: TraceData or results dict with potential traces + neuron_ids: List of neuron ID strings for subplot titles + time_range: Tuple of (start, end) timesteps to display + ncols: Number of columns in the subplot grid + colors: List of colors for each subplot + show_threshold: If set, draws threshold line on each subplot + sharex: Whether subplots share x-axis + sharey: Whether subplots share y-axis + style: Style configuration + figsize: Figure size (width, height) + **plot_kwargs: Additional arguments passed to plt.plot() + + Returns: + Tuple of (Figure, array of Axes) objects. + """ + # Handle raw results dict + if isinstance(data, dict): + data = TraceData.from_sim_results(data) + + if not data.has_potentials(): + raise ValueError("No potential trace data available") + + if style is None: + style = get_default_style() + + # Get potential data + potentials = data.potentials_to_array() + n_timesteps, n_neurons = potentials.shape + + time_offset = 0 + if time_range is not None: + start, end = time_range + potentials = potentials[start:end, :] + n_timesteps = potentials.shape[0] + time_offset = start + + # Generate default neuron labels + if neuron_ids is None: + neuron_ids = [f"Neuron {i}" for i in range(n_neurons)] + + if colors is None: + colors = [DEFAULT_COLORS[i % len(DEFAULT_COLORS)] for i in range(n_neurons)] + + # Calculate grid dimensions + nrows = (n_neurons + ncols - 1) // ncols + + # Determine figure size + if figsize is None: + width = style.figure_size[0] * min(ncols, 2) + height = style.figure_size[1] * 0.6 * nrows + figsize = (width, height) + + # Create subplots + fig, axes = plt.subplots( + nrows, ncols, + figsize=figsize, + sharex=sharex, + sharey=sharey, + squeeze=False, + ) + axes = axes.flatten() + + timesteps = np.arange(time_offset, time_offset + n_timesteps) + + # Plot each neuron + plot_defaults = { + "linewidth": style.potential_line_width, + } + plot_defaults.update(plot_kwargs) + + for i in range(n_neurons): + ax = axes[i] + ax.plot(timesteps, potentials[:, i], color=colors[i], **plot_defaults) + ax.set_title(neuron_ids[i], fontsize=style.label_size) + + if show_threshold is not None: + ax.axhline(y=show_threshold, color="#d62728", linestyle="--", + linewidth=style.line_width * 0.6, alpha=0.7) + + # Add labels to edge subplots only + if i >= n_neurons - ncols: + ax.set_xlabel("Time-step", fontsize=style.tick_size) + if i % ncols == 0: + ax.set_ylabel("Potential", fontsize=style.tick_size) + + # Hide unused subplots + for i in range(n_neurons, len(axes)): + axes[i].set_visible(False) + + fig.tight_layout() + + return fig, axes[:n_neurons] diff --git a/sanafe/viz/raster.py b/sanafe/viz/raster.py new file mode 100644 index 0000000..c631b79 --- /dev/null +++ b/sanafe/viz/raster.py @@ -0,0 +1,267 @@ +""" +raster.py - Spike raster plot visualization for SANA-FE. + +This module provides functions for creating spike raster plots, which display +neural spiking activity over time with neurons on the y-axis and timesteps +on the x-axis. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np + +from sanafe.data.traces import TraceData +from sanafe.viz.styles import ( + SANAFEStyle, + create_figure, + get_default_style, + get_group_colors, + style_axis, +) + + +def raster_plot( + data: Union[TraceData, Dict[str, Any]], + groups: Optional[Sequence[str]] = None, + time_range: Optional[Tuple[int, int]] = None, + colors: Optional[Dict[str, str]] = None, + group_spacing: float = 0.5, + show_group_labels: bool = True, + show_legend: bool = True, + ax: Optional[plt.Axes] = None, + style: Optional[SANAFEStyle] = None, + figsize: Optional[Tuple[float, float]] = None, + title: Optional[str] = None, + xlabel: str = "Time-step", + ylabel: str = "Neuron", + **scatter_kwargs, +) -> Tuple[plt.Figure, plt.Axes]: + """ + Create a spike raster plot from SANA-FE trace data. + + Displays neural spiking activity with neurons on the y-axis and timesteps + on the x-axis. Each spike is shown as a vertical tick mark. Neurons can + be colored by group. + + Args: + data: Either a TraceData object or raw results dict from chip.sim() + groups: List of group names to include. If None, includes all groups. + time_range: Tuple of (start, end) timesteps to display. If None, shows all. + colors: Dict mapping group names to color strings. If None, auto-assigned. + group_spacing: Vertical spacing between neuron groups (in neuron units) + show_group_labels: Whether to show group names on y-axis + show_legend: Whether to show a legend for group colors + ax: Existing Axes to plot on. If None, creates new figure. + style: Style configuration. If None, uses default style. + figsize: Figure size (width, height). If None, uses style default. + title: Plot title. If None, no title is shown. + xlabel: X-axis label + ylabel: Y-axis label + **scatter_kwargs: Additional arguments passed to plt.scatter() + + Returns: Tuple of (Figure, Axes) objects. + + """ + if isinstance(data, dict): + data = TraceData.from_sim_results(data) + + if not data.has_spikes(): + raise ValueError("No spike trace data available in the provided data") + + if style is None: + style = get_default_style() + + all_groups = data.get_neuron_groups() + if groups is not None: + # Validate groups exist + invalid = set(groups) - set(all_groups) + if invalid: + raise ValueError(f"Unknown groups: {invalid}. Available: {all_groups}") + groups = list(groups) + else: + groups = all_groups + + if time_range is not None: + data = data.filter_by_time(start=time_range[0], end=time_range[1]) + + if colors is None: + colors = get_group_colors(groups, style) + + # Build neuron ordering: map (group, offset) -> y-position + neurons_by_group: Dict[str, set] = {g: set() for g in groups} + for spikes_at_t in data.spike_trace: + for spike in spikes_at_t: + if spike.group_name in neurons_by_group: + neurons_by_group[spike.group_name].add(spike.neuron_offset) + + # Create y-position mapping with group spacing + neuron_to_y: Dict[Tuple[str, int], float] = {} + y_ticks: List[float] = [] + y_tick_labels: List[str] = [] + current_y = 0.0 + group_y_ranges: Dict[str, Tuple[float, float]] = {} + + for group in groups: + offsets = sorted(neurons_by_group[group]) + if not offsets: + continue + + group_start_y = current_y + for offset in offsets: + neuron_to_y[(group, offset)] = current_y + current_y += 1.0 + group_end_y = current_y - 1.0 + + group_y_ranges[group] = (group_start_y, group_end_y) + + group_center = (group_start_y + group_end_y) / 2 + y_ticks.append(group_center) + y_tick_labels.append(group) + + current_y += group_spacing + + # Prepare scatter data + timesteps = [] + y_positions = [] + spike_colors = [] + + for timestep, spikes_at_t in enumerate(data.spike_trace): + for spike in spikes_at_t: + key = (spike.group_name, spike.neuron_offset) + if key in neuron_to_y: + timesteps.append(timestep) + y_positions.append(neuron_to_y[key]) + spike_colors.append(colors.get(spike.group_name, "#333333")) + + if ax is None: + fig, ax = create_figure(figsize=figsize, style=style) + else: + fig = ax.get_figure() + + scatter_defaults = { + "marker": style.raster_marker, + "s": style.raster_marker_size, + "linewidths": style.raster_line_width, + } + scatter_defaults.update(scatter_kwargs) + + # Plot spikes + if timesteps: + ax.scatter(timesteps, y_positions, c=spike_colors, **scatter_defaults) + + if show_group_labels and y_ticks: + ax.set_yticks(y_ticks) + ax.set_yticklabels(y_tick_labels) + else: + ax.set_ylabel(ylabel) + + time_max = data.timesteps + y_max = current_y - group_spacing # Remove last spacing + + ax.set_xlim(-0.5, time_max + 0.5) + ax.set_ylim(-0.5, y_max + 0.5) + + style_axis(ax, style, xlabel=xlabel, title=title) + + if show_legend and len(groups) > 1: + from matplotlib.lines import Line2D + handles = [ + Line2D([0], [0], marker=style.raster_marker, color="w", + markerfacecolor=colors.get(g, "#333333"), + markeredgecolor=colors.get(g, "#333333"), + markersize=10, label=g) + for g in groups if g in group_y_ranges + ] + ax.legend(handles=handles, loc="upper right", framealpha=0.9) + + if style.tight_layout: + fig.tight_layout() + + return fig, ax + + +def raster_plot_matrix( + spike_matrix: np.ndarray, + neuron_ids: Optional[Sequence[str]] = None, + time_range: Optional[Tuple[int, int]] = None, + ax: Optional[plt.Axes] = None, + style: Optional[SANAFEStyle] = None, + figsize: Optional[Tuple[float, float]] = None, + title: Optional[str] = None, + xlabel: str = "Time-step", + ylabel: str = "Neuron", + color: str = "#1f77b4", + **scatter_kwargs, +) -> Tuple[plt.Figure, plt.Axes]: + """ + Create a spike raster plot from a binary spike matrix. + + Alternative to raster_plot() when you have spike data as a numpy array + rather than TraceData. + + Args: + spike_matrix: 2D binary array of shape (timesteps, neurons) where + 1 indicates a spike occurred + neuron_ids: Optional list of neuron ID strings for y-axis labels + time_range: Tuple of (start, end) timesteps to display + ax: Existing Axes to plot on + style: Style configuration + figsize: Figure size (width, height) + title: Plot title + xlabel: X-axis label + ylabel: Y-axis label + color: Color for all spikes + **scatter_kwargs: Additional arguments passed to plt.scatter() + + Returns: Tuple of (Figure, Axes) objects. + """ + if style is None: + style = get_default_style() + + if time_range is not None: + spike_matrix = spike_matrix[time_range[0]:time_range[1], :] + + n_timesteps, n_neurons = spike_matrix.shape + + spike_times, spike_neurons = np.where(spike_matrix) + + if ax is None: + fig, ax = create_figure(figsize=figsize, style=style) + else: + fig = ax.get_figure() + + scatter_defaults = { + "marker": style.raster_marker, + "s": style.raster_marker_size, + "linewidths": style.raster_line_width, + "c": color, + } + scatter_defaults.update(scatter_kwargs) + + if len(spike_times) > 0: + ax.scatter(spike_times, spike_neurons, **scatter_defaults) + + if neuron_ids is not None: + # Show subset of labels if too many neurons + if n_neurons > 20: + step = n_neurons // 10 + tick_positions = list(range(0, n_neurons, step)) + tick_labels = [neuron_ids[i] for i in tick_positions] + ax.set_yticks(tick_positions) + ax.set_yticklabels(tick_labels) + else: + ax.set_yticks(range(n_neurons)) + ax.set_yticklabels(neuron_ids) + + ax.set_xlim(-0.5, n_timesteps + 0.5) + ax.set_ylim(-0.5, n_neurons - 0.5) + + style_axis(ax, style, xlabel=xlabel, ylabel=ylabel, title=title) + + if style.tight_layout: + fig.tight_layout() + + return fig, ax diff --git a/sanafe/viz/styles.py b/sanafe/viz/styles.py new file mode 100644 index 0000000..85c3f6e --- /dev/null +++ b/sanafe/viz/styles.py @@ -0,0 +1,340 @@ +""" +This module provides consistent styling across all SANA-FE plots, including +color palettes, figure sizes, and matplotlib configuration. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import matplotlib as mpl +import matplotlib.pyplot as plt +from matplotlib.colors import LinearSegmentedColormap +import numpy as np + +DEFAULT_COLORS = [ + "#1f77b4", # Blue + "#ff7f0e", # Orange + "#2ca02c", # Green + "#d62728", # Red + "#9467bd", # Purple + "#8c564b", # Brown + "#e377c2", # Pink + "#7f7f7f", # Gray + "#bcbd22", # Olive + "#17becf", # Cyan +] + +# Extended palette +EXTENDED_COLORS = DEFAULT_COLORS + [ + "#aec7e8", # Light blue + "#ffbb78", # Light orange + "#98df8a", # Light green + "#ff9896", # Light red + "#c5b0d5", # Light purple + "#c49c94", # Light brown + "#f7b6d2", # Light pink + "#c7c7c7", # Light gray + "#dbdb8d", # Light olive + "#9edae5", # Light cyan +] + +NEUROMORPHIC_CMAP_COLORS = [ + "#0d0887", # Deep purple + "#46039f", + "#7201a8", + "#9c179e", + "#bd3786", + "#d8576b", + "#ed7953", + "#fb9f3a", + "#fdca26", + "#f0f921", # Bright yellow +] + + +@dataclass +class SANAFEStyle: + """ + Style configuration for plots + + colors: List of colors for different groups/series + figure_size: Default (width, height) in inches + dpi: Resolution for saved figures + font_family: Font family for text + font_size: Base font size in points + title_size: Title font size in points + label_size: Axis label font size in points + tick_size: Tick label font size in points + line_width: Default line width + marker_size: Default marker size for scatter plots + spine_width: Width of axis spines + grid: Whether to show grid by default + grid_alpha: Transparency of grid lines + tight_layout: Whether to use tight_layout by default + """ + colors: List[str] = field(default_factory=lambda: DEFAULT_COLORS.copy()) + figure_size: Tuple[float, float] = (8.0, 5.0) + dpi: int = 100 + font_family: str = "sans-serif" + font_size: float = 11.0 + title_size: float = 13.0 + label_size: float = 11.0 + tick_size: float = 10.0 + line_width: float = 1.5 + marker_size: float = 30.0 + spine_width: float = 1.0 + grid: bool = False + grid_alpha: float = 0.3 + tight_layout: bool = True + + # Raster plot specific + raster_marker: str = "|" + raster_marker_size: float = 100.0 + raster_line_width: float = 1.5 + + # Potential plot specific + potential_line_width: float = 1.5 + potential_marker: Optional[str] = None + potential_marker_size: float = 4.0 + + # Performance plot specific + perf_line_width: float = 1.5 + perf_marker: Optional[str] = "o" + perf_marker_size: float = 3.0 + perf_fill_alpha: float = 0.3 + + # Energy breakdown specific + energy_colors: List[str] = field(default_factory=lambda: [ + "#1f77b4", # Synapse - Blue + "#2ca02c", # Dendrite - Green + "#ff7f0e", # Soma - Orange + "#d62728", # Network - Red + ]) + energy_component_names: List[str] = field(default_factory=lambda: [ + "Synapse", "Dendrite", "Soma", "Network", + ]) + + # Histogram specific + hist_bins: int = 30 + hist_alpha: float = 0.7 + hist_edgecolor: str = "white" + hist_edgewidth: float = 0.5 + + def to_rc_params(self) -> Dict[str, Any]: + """Convert style to matplotlib rcParams dictionary.""" + return { + "figure.figsize": self.figure_size, + "figure.dpi": self.dpi, + "font.family": self.font_family, + "font.size": self.font_size, + "axes.titlesize": self.title_size, + "axes.labelsize": self.label_size, + "xtick.labelsize": self.tick_size, + "ytick.labelsize": self.tick_size, + "lines.linewidth": self.line_width, + "lines.markersize": self.marker_size, + "axes.linewidth": self.spine_width, + "axes.grid": self.grid, + "grid.alpha": self.grid_alpha, + "figure.autolayout": self.tight_layout, + } + + +PUBLICATION_STYLE = SANAFEStyle( + figure_size=(6.0, 4.0), + dpi=300, + font_family="serif", + font_size=10.0, + title_size=11.0, + label_size=10.0, + tick_size=9.0, + line_width=1.0, + spine_width=0.8, + raster_marker_size=80.0, + raster_line_width=1.0, +) + +PRESENTATION_STYLE = SANAFEStyle( + figure_size=(12.0, 7.0), + dpi=100, + font_size=14.0, + title_size=18.0, + label_size=14.0, + tick_size=12.0, + line_width=2.5, + spine_width=1.5, + raster_marker_size=150.0, + raster_line_width=2.5, +) + +NOTEBOOK_STYLE = SANAFEStyle( + figure_size=(10.0, 6.0), + dpi=100, + font_size=12.0, + title_size=14.0, + label_size=12.0, + tick_size=11.0, + grid=True, + grid_alpha=0.2, +) + +# Global default style +_default_style: SANAFEStyle = SANAFEStyle() + + +def get_default_style() -> SANAFEStyle: + return _default_style + + +def set_default_style(style: Optional[SANAFEStyle] = None) -> None: + """ + Set the default style for all SANA-FE plots. + """ + global _default_style + if style is None: + _default_style = SANAFEStyle() + else: + _default_style = style + + +def apply_style(style: Optional[SANAFEStyle] = None) -> None: + """ + Apply style settings to matplotlib's rcParams. + """ + if style is None: + style = _default_style + + for key, value in style.to_rc_params().items(): + mpl.rcParams[key] = value + + +def get_group_colors( + groups: Sequence[str], + style: Optional[SANAFEStyle] = None, +) -> Dict[str, str]: + """ + Assigns colors from the style's color palette to each group name. + Colors cycle if there are more groups than colors. + + Returns: Dictionary mapping group names to color strings. + + """ + if style is None: + style = _default_style + + colors = style.colors if len(style.colors) >= len(groups) else EXTENDED_COLORS + + return { + group: colors[i % len(colors)] + for i, group in enumerate(groups) + } + + +def get_colormap( + name: str = "neuromorphic", + n_colors: int = 256, +) -> LinearSegmentedColormap: + """ + Get a colormap for SANA-FE visualizations. + + Args: + name: Colormap name. Options: + - "neuromorphic": Purple-to-yellow gradient (default) + - "activity": Blue-to-red for activity levels + - "energy": Green-to-red for energy consumption + - Any matplotlib colormap name + n_colors: Number of discrete colors in the colormap + + Returns: Matplotlib colormap object. + """ + if name == "neuromorphic": + return LinearSegmentedColormap.from_list( + "neuromorphic", + NEUROMORPHIC_CMAP_COLORS, + N=n_colors, + ) + elif name == "activity": + return LinearSegmentedColormap.from_list( + "activity", + ["#2166ac", "#f7f7f7", "#b2182b"], + N=n_colors, + ) + elif name == "energy": + return LinearSegmentedColormap.from_list( + "energy", + ["#1a9850", "#ffffbf", "#d73027"], + N=n_colors, + ) + else: + return plt.get_cmap(name) + + +def create_figure( + figsize: Optional[Tuple[float, float]] = None, + style: Optional[SANAFEStyle] = None, + **kwargs, +) -> Tuple[plt.Figure, plt.Axes]: + """ + Create a figure with styling applied. + + Returns: Tuple of (Figure, Axes) objects. + """ + if style is None: + style = _default_style + + if figsize is None: + figsize = style.figure_size + + fig, ax = plt.subplots(figsize=figsize, **kwargs) + + # Apply spine styling + for spine in ax.spines.values(): + spine.set_linewidth(style.spine_width) + + # Apply grid if enabled + if style.grid: + ax.grid(True, alpha=style.grid_alpha) + + return fig, ax + + +def style_axis( + ax: plt.Axes, + style: Optional[SANAFEStyle] = None, + xlabel: Optional[str] = None, + ylabel: Optional[str] = None, + title: Optional[str] = None, + xlim: Optional[Tuple[float, float]] = None, + ylim: Optional[Tuple[float, float]] = None, +) -> plt.Axes: + if style is None: + style = _default_style + + # Apply spine styling + for spine in ax.spines.values(): + spine.set_linewidth(style.spine_width) + + # Apply grid if enabled + if style.grid: + ax.grid(True, alpha=style.grid_alpha) + + # Set labels and title + if xlabel is not None: + ax.set_xlabel(xlabel, fontsize=style.label_size) + if ylabel is not None: + ax.set_ylabel(ylabel, fontsize=style.label_size) + if title is not None: + ax.set_title(title, fontsize=style.title_size) + + # Set limits + if xlim is not None: + ax.set_xlim(xlim) + if ylim is not None: + ax.set_ylim(ylim) + + # Set tick label sizes + ax.tick_params(axis="both", labelsize=style.tick_size) + + return ax diff --git a/scripts/run_arch_testing.py b/scripts/run_arch_testing.py new file mode 100644 index 0000000..26aca5e --- /dev/null +++ b/scripts/run_arch_testing.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +from pathlib import Path +import sys + +REPO = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(REPO)) + +import sanafe +from sanafe.data.traces import TraceData +from sanafe.viz.raster import raster_plot +from sanafe.viz.potential import potential_plot, potential_heatmap, potential_subplots +from sanafe.viz.performance import ( + energy_breakdown_plot, throughput_plot, + latency_histogram, latency_comparison, +) + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +OUT = REPO / "tmp_arch_testing" +OUT.mkdir(exist_ok=True) + +#1. Load architecture +print("Loading architecture...") +arch = sanafe.load_arch(str(REPO / "sanafe" / "data" / "arch_testing.yaml")) + +#2. Build SNN programmatically +snn = sanafe.Network() + +# bias +inp = snn.create_neuron_group("input", 4, + model_attributes={"threshold": 1.0, "reset": 0.0}, + log_spikes=True, log_potential=True) + +# 4 neurons +hid = snn.create_neuron_group("hidden", 4, + model_attributes={"threshold": 0.8, "reset": 0.0}, + log_spikes=True, log_potential=True) + +# 2 neurons +out = snn.create_neuron_group("output", 2, + model_attributes={"threshold": 1.5, "reset": 0.0}, + log_spikes=True, log_potential=True) + +for i_n in inp: + for h_n in hid: + i_n.connect_to_neuron(h_n, {"weight": 0.5}) + +for h_n in hid: + for o_n in out: + h_n.connect_to_neuron(o_n, {"weight": 0.7}) + +inp[0].set_attributes(model_attributes={"bias": 1.2}) +inp[1].set_attributes(model_attributes={"bias": 0.9}) +inp[2].set_attributes(model_attributes={"bias": 1.1}) +inp[3].set_attributes(model_attributes={"bias": 0.6}) + +# mapped neurons to cores on different tiles +for i, n in enumerate(inp): + n.map_to_core(arch.tiles[0].cores[i]) +for i, n in enumerate(hid): + n.map_to_core(arch.tiles[1].cores[i]) +for i, n in enumerate(out): + n.map_to_core(arch.tiles[2].cores[i]) + +# 3. Simulate +print("Creating SpikingChip...") +chip = sanafe.SpikingChip(arch) +print("Loading SNN...") +chip.load(snn) +print("Running simulation (100 timesteps)...") + +results = chip.sim( + 100, + spike_trace=True, + potential_trace=True, + perf_trace=True, + message_trace=True, +) + +print(f" Total energy: {results['energy']['total']:.2e} J") +print(f" Neurons fired: {results['neurons_fired']}") +print(f" Spikes: {results['spikes']}") + +# 4. Build TraceData directly from in-memory results +tr = TraceData.from_sim_results(results) +print(f" TraceData: {tr}") + +# 5. Export CSVs using proper serialization +spike_csv = OUT / "spikes.csv" +tr.spikes_to_dataframe().to_csv(spike_csv, index=False) +print(f"Saved {spike_csv}") + +pot_csv = OUT / "potentials.csv" +tr.potentials_to_dataframe().to_csv(pot_csv, index=True) +print(f"Saved {pot_csv}") + +# 6. Generate all visualizations +print("\nGenerating plots...") + +fig, ax = raster_plot(tr, title="arch_testing — Spike Raster") +fig.savefig(OUT / "raster.png", dpi=150) +print(" Saved raster.png") + +fig, ax = potential_plot(tr, title="arch_testing — Membrane Potentials") +fig.savefig(OUT / "potentials.png", dpi=150) +print(" Saved potentials.png") + +fig, ax = potential_heatmap(tr, title="arch_testing — Potential Heatmap") +fig.savefig(OUT / "potential_heatmap.png", dpi=150) +print(" Saved potential_heatmap.png") + +fig, axes = potential_subplots(tr, ncols=2) +fig.savefig(OUT / "potential_subplots.png", dpi=150) +print(" Saved potential_subplots.png") + +fig, ax = energy_breakdown_plot(tr, title="arch_testing — Energy Breakdown") +fig.savefig(OUT / "energy_breakdown.png", dpi=150) +print(" Saved energy_breakdown.png") + +fig, ax = energy_breakdown_plot(tr, mode="stacked_bar", normalize=True, + title="arch_testing — Energy % Breakdown") +fig.savefig(OUT / "energy_normalized.png", dpi=150) +print(" Saved energy_normalized.png") + +fig, ax = throughput_plot(tr, metrics=["fired", "spikes", "hops"], + title="arch_testing — Throughput") +fig.savefig(OUT / "throughput.png", dpi=150) +print(" Saved throughput.png") + +fig, ax = latency_histogram(tr, metric="generation_delay", + title="arch_testing — Generation Delay") +fig.savefig(OUT / "latency_gen.png", dpi=150) +print(" Saved latency_gen.png") + +fig, ax = latency_comparison(tr, + metrics=["generation_delay", "receive_delay", + "network_delay", "blocked_delay"], + title="arch_testing — Latency Comparison") +fig.savefig(OUT / "latency_comparison.png", dpi=150) +print(" Saved latency_comparison.png") + +plt.close("all") +print(f"\nOutputs saved to {OUT}/") diff --git a/scripts/test_performance_viz.py b/scripts/test_performance_viz.py new file mode 100644 index 0000000..e5d8f04 --- /dev/null +++ b/scripts/test_performance_viz.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +""" + python3 scripts/test_performance_viz.py +""" +from pathlib import Path +import sys +import importlib.util + +REPO = Path(__file__).resolve().parent.parent + + +def _bootstrap_sanafe_viz(): + import types + for pkg in ("sanafe", "sanafe.data", "sanafe.viz"): + if pkg not in sys.modules: + mod = types.ModuleType(pkg) + mod.__path__ = [] + sys.modules[pkg] = mod + + def load(name, path): + spec = importlib.util.spec_from_file_location(name, path) + mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod + spec.loader.exec_module(mod) + return mod + + load("sanafe.data.traces", REPO / "sanafe" / "data" / "traces.py") + load("sanafe.viz.styles", REPO / "sanafe" / "viz" / "styles.py") + load("sanafe.viz.performance", REPO / "sanafe" / "viz" / "performance.py") + + from sanafe.viz.performance import ( + energy_breakdown_plot, throughput_plot, + latency_histogram, latency_comparison, + ) + from sanafe.data.traces import TraceData + return (TraceData, energy_breakdown_plot, throughput_plot, + latency_histogram, latency_comparison) + + +def main(): + import matplotlib.pyplot as plt + + (TraceData, energy_breakdown_plot, throughput_plot, + latency_histogram, latency_comparison) = _bootstrap_sanafe_viz() + + out = REPO / "tmp_trace_test" + out.mkdir(exist_ok=True) + + # --- Load the tutorial traces that already exist on disk --- + perf_csv = REPO / "tutorial" / "perf.csv" + msg_csv = REPO / "tutorial" / "messages.csv" + tr = TraceData.from_files(perf_csv=perf_csv, message_csv=msg_csv) + print("TraceData:", tr) + + # 1. Energy breakdown — stacked area (default) + fig, ax = energy_breakdown_plot(tr, title="Energy Breakdown (stacked area)") + fig.savefig(out / "energy_stacked_area.png", dpi=100) + print("Saved", out / "energy_stacked_area.png") + + # 2. Energy breakdown — stacked bar + fig, ax = energy_breakdown_plot(tr, mode="stacked_bar", title="Energy Breakdown (stacked bar)") + fig.savefig(out / "energy_stacked_bar.png", dpi=100) + print("Saved", out / "energy_stacked_bar.png") + + # 3. Energy breakdown — normalized percentage + fig, ax = energy_breakdown_plot(tr, mode="stacked_bar", normalize=True, + title="Energy Breakdown (normalized)") + fig.savefig(out / "energy_normalized.png", dpi=100) + print("Saved", out / "energy_normalized.png") + + # 4. Throughput + fig, ax = throughput_plot(tr, title="Throughput") + fig.savefig(out / "throughput.png", dpi=100) + print("Saved", out / "throughput.png") + + # 5. Throughput — dual axis with energy + fig, ax = throughput_plot(tr, metrics=["fired", "total_energy"], + secondary_y=["total_energy"], + title="Fired vs Total Energy") + fig.savefig(out / "throughput_dual.png", dpi=100) + print("Saved", out / "throughput_dual.png") + + # 6. Latency histogram — generation delay + fig, ax = latency_histogram(tr, metric="generation_delay", + title="Generation Delay") + fig.savefig(out / "latency_generation.png", dpi=100) + print("Saved", out / "latency_generation.png") + + # 7. Latency histogram — processing delay + fig, ax = latency_histogram(tr, metric="processing_delay", + title="Processing Delay") + fig.savefig(out / "latency_processing.png", dpi=100) + print("Saved", out / "latency_processing.png") + + # 8. Latency comparison overlay + fig, ax = latency_comparison(tr, metrics=["generation_delay", "processing_delay"]) + fig.savefig(out / "latency_comparison.png", dpi=100) + print("Saved", out / "latency_comparison.png") + + plt.close("all") + print("Done. Check files in tmp_trace_test/") + + +if __name__ == "__main__": + main() diff --git a/scripts/test_potential_viz.py b/scripts/test_potential_viz.py new file mode 100644 index 0000000..f8464c9 --- /dev/null +++ b/scripts/test_potential_viz.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +""" + python3 scripts/test_potential_viz.py +""" +from pathlib import Path +import sys +import importlib.util + +REPO = Path(__file__).resolve().parent.parent + + +def _bootstrap_sanafe_viz(): + import types + for pkg in ("sanafe", "sanafe.data", "sanafe.viz"): + if pkg not in sys.modules: + mod = types.ModuleType(pkg) + mod.__path__ = [] + sys.modules[pkg] = mod + + def load(name, path): + spec = importlib.util.spec_from_file_location(name, path) + mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod + spec.loader.exec_module(mod) + return mod + + load("sanafe.data.traces", REPO / "sanafe" / "data" / "traces.py") + load("sanafe.viz.styles", REPO / "sanafe" / "viz" / "styles.py") + load("sanafe.viz.potential", REPO / "sanafe" / "viz" / "potential.py") + + from sanafe.viz.potential import potential_plot, potential_heatmap, potential_subplots + from sanafe.data.traces import TraceData + return TraceData, potential_plot, potential_heatmap, potential_subplots + + +def main(): + import pandas as pd + import numpy as np + + TraceData, potential_plot, potential_heatmap, potential_subplots = _bootstrap_sanafe_viz() + + out = REPO / "tmp_trace_test" + out.mkdir(exist_ok=True) + + pot_csv = out / "potentials.csv" + neuron_ids = ["in.0", "in.1", "out.0"] + pd.DataFrame({ + "in.0": [0.0, 0.1, 0.2, 0.3, 0.25, 0.1], + "in.1": [0.0, 0.0, 0.5, 0.0, 0.2, 0.0], + "out.0": [0.0, 0.0, 0.0, 0.9, 0.1, 0.0], + }).to_csv(pot_csv, index=False) + + tr = TraceData.from_files(potential_csv=pot_csv) + print("TraceData:", tr) + + # Timeseries plot + fig1, ax1 = potential_plot( + tr, + neuron_ids=neuron_ids, + title="Potential test (from CSV)", + show_threshold=0.5, + ) + fig1.savefig(out / "potential_plot.png", dpi=100) + print("Saved", out / "potential_plot.png") + + # Heatmap + fig2, ax2 = potential_heatmap( + tr, + neuron_ids=neuron_ids, + title="Potential heatmap (from CSV)", + ) + fig2.savefig(out / "potential_heatmap.png", dpi=100) + print("Saved", out / "potential_heatmap.png") + + # Subpotentials + fig3, axes = potential_subplots(tr, neuron_ids=neuron_ids, ncols=1) + fig3.savefig(out / "potential_subplots.png", dpi=100) + print("Saved", out / "potential_subplots.png") + + # Heatmap from array + arr = np.random.rand(20, 3).astype(np.float64) * 0.8 + fig4, ax4 = potential_heatmap(arr, neuron_ids=["A", "B", "C"], title="From numpy array") + fig4.savefig(out / "potential_heatmap_array.png", dpi=100) + print("Saved", out / "potential_heatmap_array.png") + + try: + import matplotlib.pyplot as plt + plt.close("all") + except Exception: + pass + print("Done. Check files in tmp_trace_test/") + + +if __name__ == "__main__": + main() diff --git a/scripts/test_raster_viz.py b/scripts/test_raster_viz.py new file mode 100644 index 0000000..300fb29 --- /dev/null +++ b/scripts/test_raster_viz.py @@ -0,0 +1,87 @@ +from pathlib import Path +import sys +import importlib.util + +REPO = Path(__file__).resolve().parent.parent + + +def _bootstrap_sanafe_viz(): + """Load sanafe.data.traces and sanafe.viz.raster without running sanafe/__init__.py.""" + import types + for pkg in ("sanafe", "sanafe.data", "sanafe.viz"): + if pkg not in sys.modules: + mod = types.ModuleType(pkg) + mod.__path__ = [] + sys.modules[pkg] = mod + + def load(name, path): + spec = importlib.util.spec_from_file_location(name, path) + mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod + spec.loader.exec_module(mod) + return mod + + load("sanafe.data.traces", REPO / "sanafe" / "data" / "traces.py") + load("sanafe.viz.styles", REPO / "sanafe" / "viz" / "styles.py") + load("sanafe.viz.raster", REPO / "sanafe" / "viz" / "raster.py") + + from sanafe.viz.raster import raster_plot, raster_plot_matrix + from sanafe.data.traces import TraceData + return TraceData, raster_plot, raster_plot_matrix + + +def main(): + import pandas as pd + + TraceData, raster_plot, raster_plot_matrix = _bootstrap_sanafe_viz() + + out = REPO / "tmp_trace_test" + out.mkdir(exist_ok=True) + + spikes_csv = out / "spikes.csv" + pd.DataFrame([ + {"neuron": "in.0", "timestep": 0}, + {"neuron": "in.1", "timestep": 0}, + {"neuron": "in.0", "timestep": 1}, + {"neuron": "in.1", "timestep": 2}, + {"neuron": "out.0", "timestep": 2}, + {"neuron": "out.0", "timestep": 3}, + {"neuron": "out.1", "timestep": 4}, + ]).to_csv(spikes_csv, index=False) + + tr = TraceData.from_files(spike_csv=spikes_csv) + print("TraceData:", tr) + print("Groups:", tr.get_neuron_groups()) + + fig1, ax1 = raster_plot(tr, title="Raster (from CSV, all groups)") + fig1.savefig(out / "raster_plot.png", dpi=100) + print("Saved", out / "raster_plot.png") + + fig2, ax2 = raster_plot( + tr, + groups=["in", "out"], + time_range=(0, 4), + title="Raster (groups=in,out; time 0–4)", + ) + fig2.savefig(out / "raster_plot_filtered.png", dpi=100) + print("Saved", out / "raster_plot_filtered.png") + + matrix, neuron_ids = tr.spikes_to_matrix() + fig3, ax3 = raster_plot_matrix( + matrix, + neuron_ids=neuron_ids, + title="Raster from matrix (spikes_to_matrix)", + ) + fig3.savefig(out / "raster_plot_matrix.png", dpi=100) + print("Saved", out / "raster_plot_matrix.png") + + try: + import matplotlib.pyplot as plt + plt.close("all") + except Exception: + pass + print("Done. Check files in tmp_trace_test/") + + +if __name__ == "__main__": + main()