diff --git a/README.md b/README.md
index b58545c..153adc1 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,8 @@ New contributors: Ryssa MOFFAT, Marine Gautier MARTINS, Rémy RAMADOUR, Patrice
 `pip install HyPyP`
 
+> To install the dependencies for the optimized algorithms, see **Poetry Installation** below.
+
 ## Documentation
 
 HyPyP documentation of all the API functions is available online at [hypyp.readthedocs.io](https://hypyp.readthedocs.io/)
@@ -51,6 +53,7 @@ For getting started with HyPyP, we have designed a little walkthrough: [getting_
 
 📊 [shiny/\*.py](https://github.com/ppsp-team/HyPyP/blob/master/hypyp/app) — Shiny dashboards, install using `poetry install --extras shiny` (Patrice)
 
+
 ## Poetry Installation (Only for Developers and Adventurous Users)
 
 To develop HyPyP, we recommend using [Poetry 2.x](https://python-poetry.org/). Follow these steps:
@@ -81,6 +84,26 @@ To install development dependencies, you can run:
 poetry install --with dev
 ```
 
+#### 3.1 Installing optimizations
+
+To use the `numba` optimizations, run:
+
+```bash
+poetry install --with optim_numba
+```
+
+To use the `torch` optimizations, run:
+
+```bash
+poetry install --with optim_torch
+```
+
+You can also add these dependencies to a pre-installed HyPyP with:
+
+```bash
+poetry sync --with optim_torch
+```
+
 ### 4. Launch Jupyter Lab to Run Notebooks:
 
 Instead of entering a shell, launch Jupyter Lab directly within the Poetry environment:
diff --git a/hypyp/analyses.py b/hypyp/analyses.py
index 09efd42..f11e22f 100644
--- a/hypyp/analyses.py
+++ b/hypyp/analyses.py
@@ -19,9 +19,15 @@ import statsmodels.stats.multitest
 import copy
 from collections import namedtuple
-from typing import Union, List, Tuple
+from typing import Union, List, Tuple, Optional
 import matplotlib.pyplot as plt
 from tqdm import tqdm
+from hypyp.sync.accorr import accorr
+from hypyp.sync.utils import (
+    _multiply_conjugate,
+    _multiply_conjugate_time,
+    _multiply_product
+)
 
 plt.ion()
 
@@ -435,7 +441,8 @@ def pair_connectivity(data: Union[list, np.ndarray], sampling_rate: int,
     return result
 
-def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = True) -> np.ndarray:
+def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = True,
+                 optimization: Optional[str] = None) -> np.ndarray:
     """
     Computes frequency-domain connectivity measures from analytic signals.
 
@@ -463,6 +470,12 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T
         If True, connectivity values are averaged across epochs (default)
         If False, epoch-by-epoch connectivity is preserved
 
+    optimization : str, optional
+        Selects an optimization strategy; may require extra dependencies
+        (see the README for installation instructions).
+        Only available for 'accorr'. See sync.accorr.accorr for a description
+        of the optimization options and their related dependencies.
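+
+        A minimal usage sketch (illustrative; assumes the optional extras
+        described in the README are installed):
+        ``compute_sync(complex_signal, mode='accorr', optimization='torch_cpu')``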
+ Returns ------- con : np.ndarray @@ -504,6 +517,7 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T # calculate all epochs at once, the only downside is that the disk may not have enough space complex_signal = complex_signal.transpose((1, 3, 0, 2, 4)).reshape(n_epoch, n_freq, 2 * n_ch, n_samp) transpose_axes = (0, 1, 3, 2) + if mode.lower() == 'plv': phase = complex_signal / np.abs(complex_signal) c = np.real(phase) @@ -552,7 +566,8 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T np.sum(angle ** 2, axis=3)))) elif mode.lower() == 'accorr': - con = _accorr_hybrid(complex_signal, epochs_average=epochs_average, show_progress=True) + con = accorr(complex_signal, epochs_average=epochs_average, + show_progress=True, optimization=optimization) return con elif mode.lower() == 'pli': @@ -1074,243 +1089,4 @@ def xwt(sig1: mne.Epochs, sig2: mne.Epochs, freqs: Union[int, np.ndarray], else: data = 'Please specify a valid mode: power, phase, xwt, or wtc.' print(data) - return data - - -# helper function -def _multiply_conjugate(real: np.ndarray, imag: np.ndarray, transpose_axes: tuple) -> np.ndarray: - """ - Computes the product of a complex array and its conjugate efficiently. - - This helper function performs matrix multiplication between complex arrays - represented by their real and imaginary parts, collapsing the last dimension. - - Parameters - ---------- - real : np.ndarray - Real part of the complex array - - imag : np.ndarray - Imaginary part of the complex array - - transpose_axes : tuple - Axes to transpose for matrix multiplication - - Returns - ------- - product : np.ndarray - Product of the array and its complex conjugate - - Notes - ----- - This function implements the formula: - product = (real × real.T + imag × imag.T) - i(real × imag.T - imag × real.T) - - Using einsum for efficient computation without explicitly creating complex arrays. - """ - - formula = 'jilm,jimk->jilk' - product = np.einsum(formula, real, real.transpose(transpose_axes)) + \ - np.einsum(formula, imag, imag.transpose(transpose_axes)) - 1j * \ - (np.einsum(formula, real, imag.transpose(transpose_axes)) - \ - np.einsum(formula, imag, real.transpose(transpose_axes))) - - return product - - -# helper function -def _multiply_conjugate_time(real: np.ndarray, imag: np.ndarray, transpose_axes: tuple) -> np.ndarray: - """ - Computes the product of a complex array and its conjugate without collapsing time dimension. - - Similar to _multiply_conjugate, but preserves the time dimension, which is - needed for certain connectivity metrics like wPLI. - - Parameters - ---------- - real : np.ndarray - Real part of the complex array - - imag : np.ndarray - Imaginary part of the complex array - - transpose_axes : tuple - Axes to transpose for matrix multiplication - - Returns - ------- - product : np.ndarray - Product of the array and its complex conjugate with time dimension preserved - - Notes - ----- - This function uses a different einsum formula than _multiply_conjugate: - 'jilm,jimk->jilkm' instead of 'jilm,jimk->jilk' - - This preserves the time dimension (m) in the output, which is necessary for - computing metrics that require individual time point values rather than - time-averaged products. 
- """ - formula = 'jilm,jimk->jilkm' - product = np.einsum(formula, real, real.transpose(transpose_axes)) + \ - np.einsum(formula, imag, imag.transpose(transpose_axes)) - 1j * \ - (np.einsum(formula, real, imag.transpose(transpose_axes)) - \ - np.einsum(formula, imag, real.transpose(transpose_axes))) - - return product - - -# helper function -def _multiply_product(real: np.ndarray, imag: np.ndarray, transpose_axes: tuple) -> np.ndarray: - """ - Computes the product of two complex arrays (not conjugate) efficiently. - - This helper function performs matrix multiplication between complex arrays - represented by their real and imaginary parts, collapsing the last dimension. - Unlike _multiply_conjugate, this computes z1 * z2 instead of z1 * conj(z2). - - Parameters - ---------- - real : np.ndarray - Real part of the complex array - - imag : np.ndarray - Imaginary part of the complex array - - transpose_axes : tuple - Axes to transpose for matrix multiplication - - Returns - ------- - product : np.ndarray - Product of the array with itself (non-conjugate) - - Notes - ----- - This function implements the formula for z1 * z2: - product = (real × real.T - imag × imag.T) + i(real × imag.T + imag × real.T) - - Using einsum for efficient computation without explicitly creating complex arrays. - This is used in the adjusted circular correlation (accorr) metric. - """ - formula = 'jilm,jimk->jilk' - product = np.einsum(formula, real, real.transpose(transpose_axes)) - \ - np.einsum(formula, imag, imag.transpose(transpose_axes)) + 1j * \ - (np.einsum(formula, real, imag.transpose(transpose_axes)) + \ - np.einsum(formula, imag, real.transpose(transpose_axes))) - - return product - - -# helper function -def _accorr_hybrid(complex_signal: np.ndarray, epochs_average: bool = True, - show_progress: bool = True) -> np.ndarray: - """ - Computes Adjusted Circular Correlation using a hybrid approach. - - This function calculates the adjusted circular correlation coefficient between - all channel pairs. It uses a vectorized computation for the numerator and an - exact loop-based computation for the denominator. - - Parameters - ---------- - complex_signal : np.ndarray - Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times) - Note: This is the already reshaped signal from compute_sync. - - epochs_average : bool, optional - If True, connectivity values are averaged across epochs (default) - If False, epoch-by-epoch connectivity is preserved - - show_progress : bool, optional - If True, display a progress bar during computation (default) - If False, no progress bar is shown - - Returns - ------- - con : np.ndarray - Adjusted circular correlation matrix with shape: - - If epochs_average=True: (n_freq, 2*n_channels, 2*n_channels) - - If epochs_average=False: (n_freq, n_epochs, 2*n_channels, 2*n_channels) - - Notes - ----- - The adjusted circular correlation is computed as: - - 1. Numerator (vectorized): Uses the difference between the absolute values of - the conjugate product and the direct product of normalized complex signals. - - 2. Denominator (loop): For each channel pair, computes optimal phase centering - parameters (m_adj, n_adj) that minimize the denominator, then calculates - the normalization factor. - - This metric provides a more accurate measure of circular correlation by - adjusting the phase centering for each channel pair individually, rather than - using a global circular mean. - - References - ---------- - Zimmermann, M., Schultz-Nielsen, K., Dumas, G., & Konvalinka, I. 
(2024). - Arbitrary methodological decisions skew inter-brain synchronization estimates - in hyperscanning-EEG studies. Imaging Neuroscience, 2. - https://doi.org/10.1162/imag_a_00350 - """ - n_epochs = complex_signal.shape[0] - n_freq = complex_signal.shape[1] - n_ch_total = complex_signal.shape[2] - - transpose_axes = (0, 1, 3, 2) - - # Numerator (vectorized) - z = complex_signal / np.abs(complex_signal) - c, s = np.real(z), np.imag(z) - - cross_conj = _multiply_conjugate(c, s, transpose_axes=transpose_axes) - r_minus = np.abs(cross_conj) - - cross_prod = _multiply_product(c, s, transpose_axes=transpose_axes) - r_plus = np.abs(cross_prod) - - num = r_minus - r_plus - - # Denominator (loop) - angle = np.angle(complex_signal) - den = np.zeros((n_epochs, n_freq, n_ch_total, n_ch_total)) - - total_pairs = (n_ch_total * (n_ch_total + 1)) // 2 - pbar = tqdm(total=total_pairs, desc=" accorr (denominator)", - disable=not show_progress, leave=False) - - for i in range(n_ch_total): - for j in range(i, n_ch_total): - alpha1 = angle[:, :, i, :] - alpha2 = angle[:, :, j, :] - - phase_diff = alpha1 - alpha2 - phase_sum = alpha1 + alpha2 - - mean_diff = np.angle(np.mean(np.exp(1j * phase_diff), axis=2, keepdims=True)) - mean_sum = np.angle(np.mean(np.exp(1j * phase_sum), axis=2, keepdims=True)) - - n_adj = -1 * (mean_diff - mean_sum) / 2 - m_adj = mean_diff + n_adj - - x_sin = np.sin(alpha1 - m_adj) - y_sin = np.sin(alpha2 - n_adj) - - den_ij = 2 * np.sqrt(np.sum(x_sin**2, axis=2) * np.sum(y_sin**2, axis=2)) - den[:, :, i, j] = den_ij - den[:, :, j, i] = den_ij - - pbar.update(1) - - pbar.close() - - den = np.where(den == 0, 1, den) - con = num / den - con = con.swapaxes(0, 1) # n_freq x n_epoch x 2*n_ch x 2*n_ch - - if epochs_average: - con = np.nanmean(con, axis=1) - - return con \ No newline at end of file + return data \ No newline at end of file diff --git a/hypyp/sync/accorr.py b/hypyp/sync/accorr.py new file mode 100644 index 0000000..1a45252 --- /dev/null +++ b/hypyp/sync/accorr.py @@ -0,0 +1,420 @@ +try: + import torch + TORCH_AVAILABLE = True + MPS_AVAILABLE = torch.backends.mps.is_available() +except ImportError: + TORCH_AVAILABLE = False + MPS_AVAILABLE = False + +try: + from numba import njit, prange + NUMBA_AVAILABLE = True +except ImportError: + NUMBA_AVAILABLE = False + +import numpy as np +from tqdm import tqdm +from numpy.typing import NDArray +from hypyp.sync.utils import _multiply_conjugate, _multiply_product +from typing import Optional + +NUMBA_OPTIMIZATION = 'numba' +TORCH_CPU_OPTIMIZATION = 'torch_cpu' +TORCH_MPS_OPTIMIZATION = 'torch_mps' + +def accorr(complex_signal: np.ndarray, epochs_average: bool = True, + show_progress: bool = True, optimization: Optional[str] = None) -> np.ndarray: + """ + Computes Adjusted Circular Correlation. 
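+
+    This is the public entry point: it validates the ``optimization`` argument
+    and dispatches to the plain NumPy, numba, or torch implementation defined
+    below. A minimal call sketch (illustrative):
+
+    >>> con = accorr(complex_signal, epochs_average=True, optimization='numba')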
+
+    Parameters
+    ----------
+    complex_signal : np.ndarray
+        Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times)
+
+    epochs_average : bool, optional
+        If True, connectivity values are averaged across epochs (default)
+        If False, epoch-by-epoch connectivity is preserved
+
+    show_progress : bool, optional
+        If True, display a progress bar during computation (default)
+        If False, no progress bar is shown (the progress bar is lighter in
+        this version)
+
+    optimization : str, optional
+        If None, execution runs on the CPU with no additional Python libraries
+        If 'numba', execution runs on the CPU using numba optimization
+        (just-in-time compilation)
+        DISCLAIMER: currently, this optimization provides no speed-up because
+        parallelization is not working (see https://github.com/ppsp-team/HyPyP/pull/246/)
+        If 'torch_cpu', execution is parallelized on the CPU with the PyTorch
+        numeric library
+        If 'torch_mps', execution is parallelized with the PyTorch numeric
+        library using MPS ([Apple's Metal Performance Shaders](https://huggingface.co/docs/accelerate/usage_guides/mps))
+
+    Returns
+    -------
+    con : np.ndarray
+        Adjusted circular correlation matrix with shape:
+        - If epochs_average=True: (n_freq, 2*n_channels, 2*n_channels)
+        - If epochs_average=False: (n_freq, n_epochs, 2*n_channels, 2*n_channels)
+
+    References
+    ----------
+    Zimmermann, M., Schultz-Nielsen, K., Dumas, G., & Konvalinka, I. (2024).
+    Arbitrary methodological decisions skew inter-brain synchronization estimates
+    in hyperscanning-EEG studies. Imaging Neuroscience, 2.
+    https://doi.org/10.1162/imag_a_00350
+    """
+
+    if optimization is None:
+        return _accorr_hybrid_precompute(complex_signal, epochs_average, show_progress)
+    elif optimization == NUMBA_OPTIMIZATION:
+        if NUMBA_AVAILABLE:
+            return _accorr_hybrid_precompute_numba(complex_signal,
+                                                   epochs_average)
+        else:
+            raise ValueError('Numba library not available for the selected optimization')
+    elif optimization in [TORCH_CPU_OPTIMIZATION, TORCH_MPS_OPTIMIZATION]:
+        if not TORCH_AVAILABLE:
+            raise ValueError('Torch library not available for the selected optimization')
+        if optimization == TORCH_MPS_OPTIMIZATION:
+            if MPS_AVAILABLE:
+                return _accorr_hybrid_precompute_torch_loop(complex_signal,
+                                                            epochs_average,
+                                                            show_progress,
+                                                            device='mps')
+            else:
+                raise ValueError('MPS not available on this device for the selected optimization')
+        elif optimization == TORCH_CPU_OPTIMIZATION:
+            return _accorr_hybrid_precompute_torch_loop(complex_signal,
+                                                        epochs_average,
+                                                        show_progress,
+                                                        device='cpu')
+    else:
+        raise ValueError(
+            f'Optimization parameter is not one of the accepted values ('
+            f'{NUMBA_OPTIMIZATION}, {TORCH_CPU_OPTIMIZATION}, '
+            f'{TORCH_MPS_OPTIMIZATION})'
+        )
+
+
+def _accorr_hybrid_precompute(
+    complex_signal: NDArray[np.complexfloating],
+    epochs_average: bool = True,
+    show_progress: bool = True
+) -> NDArray[np.floating]:
+    """
+    Computes Adjusted Circular Correlation using an optimized hybrid approach.
+
+    This is an optimized version that pre-computes m_adj and n_adj for ALL pairs
+    by reusing the cross_conj and cross_prod matrices, reducing computation in the
+    denominator loop.
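+
+    In symbols, for unit-modulus ``z_k = exp(i * alpha_k)`` the time average of
+    ``z_i * conj(z_j)`` is the circular mean resultant of ``alpha_i - alpha_j``,
+    and that of ``z_i * z_j`` the resultant of ``alpha_i + alpha_j``; both are
+    therefore already available as ``cross_conj / n_times`` and
+    ``cross_prod / n_times`` from the numerator computation.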
+
+    See Also
+    --------
+    accorr : Main function with full parameter and return value descriptions
+
+    Notes
+    -----
+    Key optimization: Instead of computing mean_diff and mean_sum for each pair
+    in the loop, we pre-compute them for all pairs at once by reusing:
+    - cross_conj / n_times gives the mean of exp(i*(alpha1 - alpha2))
+    - cross_prod / n_times gives the mean of exp(i*(alpha1 + alpha2))
+
+    This significantly reduces the number of exp() and mean() operations.
+    """
+    n_epochs, n_freq, n_ch_total, n_times = complex_signal.shape
+    transpose_axes = (0, 1, 3, 2)
+
+    # Numerator (vectorized)
+    z = complex_signal / np.abs(complex_signal)
+    c, s = np.real(z), np.imag(z)
+
+    cross_conj = _multiply_conjugate(c, s, transpose_axes)
+    cross_prod = _multiply_product(c, s, transpose_axes)
+
+    r_minus = np.abs(cross_conj)
+    r_plus = np.abs(cross_prod)
+    num = r_minus - r_plus
+
+    # === OPTIMIZATION: Pre-compute m_adj and n_adj for ALL pairs ===
+    # cross_conj[i,j] = sum(z_i * conj(z_j)) = sum(exp(i*(alpha_i - alpha_j)))
+    # cross_prod[i,j] = sum(z_i * z_j) = sum(exp(i*(alpha_i + alpha_j)))
+    mean_diff_all = np.angle(cross_conj / n_times)  # Reuses cross_conj!
+    mean_sum_all = np.angle(cross_prod / n_times)  # Reuses cross_prod!
+
+    n_adj_all = -1 * (mean_diff_all - mean_sum_all) / 2
+    m_adj_all = mean_diff_all + n_adj_all
+
+    # Denominator (lighter loop - just lookups, no more circular mean computation)
+    angle = np.angle(complex_signal)
+    den = np.zeros((n_epochs, n_freq, n_ch_total, n_ch_total))
+
+    total_pairs = (n_ch_total * (n_ch_total + 1)) // 2
+    pbar = tqdm(total=total_pairs, desc="  accorr_opt (denominator)",
+                disable=not show_progress, leave=False)
+
+    for i in range(n_ch_total):
+        for j in range(i, n_ch_total):
+            alpha1 = angle[:, :, i, :]
+            alpha2 = angle[:, :, j, :]
+
+            # Just lookup, no more computation!
+            m_adj = m_adj_all[:, :, i, j, np.newaxis]
+            n_adj = n_adj_all[:, :, i, j, np.newaxis]
+
+            x_sin = np.sin(alpha1 - m_adj)
+            y_sin = np.sin(alpha2 - n_adj)
+
+            den_ij = 2 * np.sqrt(np.sum(x_sin**2, axis=2) * np.sum(y_sin**2, axis=2))
+            den[:, :, i, j] = den_ij
+            den[:, :, j, i] = den_ij
+
+            pbar.update(1)
+
+    pbar.close()
+
+    den = np.where(den == 0, 1, den)
+    con = num / den
+    con = con.swapaxes(0, 1)  # n_freq x n_epoch x 2*n_ch x 2*n_ch
+
+    if epochs_average:
+        con = np.nanmean(con, axis=1)
+
+    return con
+
+
+if NUMBA_AVAILABLE:
+    # TODO(@m2march): research why parallelization is not working
+    @njit(parallel=False, cache=True)
+    def _accorr_den_calc_precalc(n_epochs: int, n_freq: int, n_ch_total: int,
+                                 angle: np.ndarray, m_adj_all: np.ndarray,
+                                 n_adj_all: np.ndarray):
+        """
+        Computes the denominator for adjusted circular correlation using precomputed m_adj and n_adj.
+
+        This helper function is JIT-compiled with numba for performance.
+        It computes the denominator values for all channel pairs using lookup tables of
+        precomputed m_adj and n_adj values.
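+
+        Note that the broadcast subtraction of the NumPy version is unrolled
+        into explicit element-wise loops below, presumably so that numba's
+        nopython mode can compile it; the loops are equivalent to
+        ``np.sin(alpha1 - m_adj[..., None])``.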
+
+        Parameters
+        ----------
+        n_epochs : int
+            Number of epochs
+        n_freq : int
+            Number of frequency bands
+        n_ch_total : int
+            Total number of channels
+        angle : np.ndarray
+            Phase angles with shape (n_epochs, n_freq, n_ch_total, n_times)
+        m_adj_all : np.ndarray
+            Precomputed m_adj values with shape (n_epochs, n_freq, n_ch_total, n_ch_total)
+        n_adj_all : np.ndarray
+            Precomputed n_adj values with shape (n_epochs, n_freq, n_ch_total, n_ch_total)
+
+        Returns
+        -------
+        den : np.ndarray
+            Denominator matrix with shape (n_epochs, n_freq, n_ch_total, n_ch_total)
+        """
+        den = np.zeros((n_epochs, n_freq, n_ch_total, n_ch_total))
+
+        for i in range(den.shape[2]):
+            for j in range(i, den.shape[3]):
+                alpha1 = angle[:, :, i, :]
+                alpha2 = angle[:, :, j, :]
+
+                # Just lookup, no more computation!
+                m_adj = m_adj_all[:, :, i, j]
+                n_adj = n_adj_all[:, :, i, j]
+
+                x = alpha1.copy()
+                for xi in range(x.shape[0]):
+                    for xj in range(x.shape[1]):
+                        for xk in range(x.shape[2]):
+                            x[xi, xj, xk] -= m_adj[xi, xj]
+                x_sin = np.sin(x)
+
+                y = alpha2.copy()
+                for yi in range(y.shape[0]):
+                    for yj in range(y.shape[1]):
+                        for yk in range(y.shape[2]):
+                            y[yi, yj, yk] -= n_adj[yi, yj]
+                y_sin = np.sin(y)
+
+                den_ij = 2 * np.sqrt(np.sum(x_sin**2, axis=2) * np.sum(y_sin**2, axis=2))
+                den[:, :, i, j] = den_ij
+                den[:, :, j, i] = den_ij
+
+        return den
+
+
+    def _accorr_hybrid_precompute_numba(
+        complex_signal: NDArray[np.complexfloating],
+        epochs_average: bool = True,
+    ) -> NDArray[np.floating]:
+        """
+        Computes Adjusted Circular Correlation using numba for optimization.
+
+        Notes
+        -----
+        This optimized version pre-compiles the main loops of
+        _accorr_hybrid_precompute with the numba library; parallel execution
+        is currently disabled (see the TODO on _accorr_den_calc_precalc).
+
+        See Also
+        --------
+        accorr : Main function with full parameter and return value descriptions
+        """
+        n_epochs, n_freq, n_ch_total, n_times = complex_signal.shape
+        transpose_axes = (0, 1, 3, 2)
+
+        # Numerator (vectorized)
+        z = complex_signal / np.abs(complex_signal)
+        c, s = np.real(z), np.imag(z)
+
+        cross_conj = _multiply_conjugate(c, s, transpose_axes)
+        cross_prod = _multiply_product(c, s, transpose_axes)
+
+        r_minus = np.abs(cross_conj)
+        r_plus = np.abs(cross_prod)
+        num = r_minus - r_plus
+
+        # === OPTIMIZATION: Pre-compute m_adj and n_adj for ALL pairs ===
+        # cross_conj[i,j] = sum(z_i * conj(z_j)) = sum(exp(i*(alpha_i - alpha_j)))
+        # cross_prod[i,j] = sum(z_i * z_j) = sum(exp(i*(alpha_i + alpha_j)))
+        mean_diff_all = np.angle(cross_conj / n_times)  # Reuses cross_conj!
+        mean_sum_all = np.angle(cross_prod / n_times)  # Reuses cross_prod!
+
+        n_adj_all = -1 * (mean_diff_all - mean_sum_all) / 2
+        m_adj_all = mean_diff_all + n_adj_all
+
+        # Denominator (lighter loop - just lookups, no more circular mean computation)
+        angle = np.angle(complex_signal)
+
+        den = _accorr_den_calc_precalc(n_epochs, n_freq, n_ch_total, angle, m_adj_all, n_adj_all)
+
+        den = np.where(den == 0, 1, den)
+        con = num / den
+        con = con.swapaxes(0, 1)  # n_freq x n_epoch x 2*n_ch x 2*n_ch
+
+        if epochs_average:
+            con = np.nanmean(con, axis=1)
+
+        return con
+
+
+if TORCH_AVAILABLE:
+    def _accorr_hybrid_precompute_torch_loop(
+        complex_signal: NDArray[np.complexfloating],
+        epochs_average: bool = True,
+        show_progress: bool = True,
+        device='cpu'
+    ) -> NDArray[np.floating]:
+        """
+        Computes Adjusted Circular Correlation using PyTorch for optimization.
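+
+        On the 'mps' device the inputs are cast to single precision
+        (``complex64``), since Apple's Metal backend does not support 64-bit
+        floats; on 'cpu' double precision is kept so that results match the
+        NumPy reference.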
+
+        Parameters
+        ----------
+        device : str
+            If 'cpu', computations are carried out on the CPU
+            If 'mps', computations are carried out using
+            [Apple's Metal Performance Shaders](https://huggingface.co/docs/accelerate/usage_guides/mps)
+
+        Notes
+        -----
+        This version uses PyTorch's numeric operations for optimization.
+        It also allows using special hardware (MPS) by pushing the
+        precalculated vectors to the device and running the main loops over
+        them (see _accorr_hybrid_precompute for details on the precomputation).
+
+        See Also
+        --------
+        accorr : Main function with full parameter and return value descriptions
+        """
+        SUPPORTED_DEVICES = ['cpu', 'mps']
+        if device not in SUPPORTED_DEVICES:
+            raise ValueError(f'Unsupported device requested, must be one of {SUPPORTED_DEVICES}')
+
+        if device == 'mps':
+            if not MPS_AVAILABLE:
+                raise ValueError('MPS device requested, but not supported on this device')
+            else:
+                float_type = torch.float32
+                complex_type = torch.complex64
+        else:
+            float_type = torch.float64
+            complex_type = torch.complex128
+
+        # Convert to torch tensors (double precision on CPU to match numpy;
+        # MPS only supports single precision)
+        complex_tensor = torch.from_numpy(complex_signal).to(device=device, dtype=complex_type)
+
+        n_epochs, n_freq, n_ch_total, n_times = complex_tensor.shape
+
+        # Numerator (vectorized)
+        z = complex_tensor / torch.abs(complex_tensor)
+        c, s = z.real, z.imag
+
+        # Cross products using einsum
+        formula = 'efit,efjt->efij'
+
+        # _multiply_conjugate: (real × real.T + imag × imag.T) - i(real × imag.T - imag × real.T)
+        cross_conj = (torch.einsum(formula, c, c) + torch.einsum(formula, s, s)) - 1j * \
+                     (torch.einsum(formula, c, s) - torch.einsum(formula, s, c))
+
+        # _multiply_product: (real × real.T - imag × imag.T) + i(real × imag.T + imag × real.T)
+        cross_prod = (torch.einsum(formula, c, c) - torch.einsum(formula, s, s)) + 1j * \
+                     (torch.einsum(formula, c, s) + torch.einsum(formula, s, c))
+
+        r_minus = torch.abs(cross_conj)
+        r_plus = torch.abs(cross_prod)
+        num = r_minus - r_plus
+
+        # Pre-compute m_adj and n_adj for ALL pairs
+        mean_diff_all = torch.angle(cross_conj / n_times)
+        mean_sum_all = torch.angle(cross_prod / n_times)
+
+        n_adj_all = -0.5 * (mean_diff_all - mean_sum_all)
+        m_adj_all = mean_diff_all + n_adj_all
+
+        # Denominator - loop-based but on device
+        angle = torch.angle(complex_tensor)
+        den = torch.zeros((n_epochs, n_freq, n_ch_total, n_ch_total), device=device, dtype=float_type)
+
+        total_pairs = (n_ch_total * (n_ch_total + 1)) // 2
+        pbar = tqdm(total=total_pairs, desc="  accorr_torch (denominator)",
+                    disable=not show_progress, leave=False)
+
+        for i in range(n_ch_total):
+            for j in range(i, n_ch_total):
+                alpha1 = angle[:, :, i, :]  # [e, f, t]
+                alpha2 = angle[:, :, j, :]  # [e, f, t]
+
+                # Lookup precomputed values
+                m_adj = m_adj_all[:, :, i, j].unsqueeze(-1)  # [e, f, 1]
+                n_adj = n_adj_all[:, :, i, j].unsqueeze(-1)  # [e, f, 1]
+
+                x_sin = torch.sin(alpha1 - m_adj)
+                y_sin = torch.sin(alpha2 - n_adj)
+
+                den_ij = 2 * torch.sqrt(torch.sum(x_sin**2, dim=2) * torch.sum(y_sin**2, dim=2))
+                den[:, :, i, j] = den_ij
+                den[:, :, j, i] = den_ij
+
+                pbar.update(1)
+
+        pbar.close()
+
+        # Avoid division by zero
+        den = torch.where(den == 0, torch.ones_like(den), den)
+
+        # Compute connectivity
+        con = num / den
+        con = con.permute(1, 0, 2, 3)  # [n_freq, n_epochs, n_ch, n_ch]
+
+        if epochs_average:
+            con = torch.nanmean(con, dim=1)
+
+        # Convert back to numpy
+        return con.cpu().numpy()
+
diff --git a/hypyp/sync/utils.py new
file mode 100644 index 0000000..fd3f5b8 --- /dev/null +++ b/hypyp/sync/utils.py @@ -0,0 +1,125 @@ +import numpy as np + +# helper function +def _multiply_conjugate(real: np.ndarray, imag: np.ndarray, transpose_axes: tuple) -> np.ndarray: + """ + Computes the product of a complex array and its conjugate efficiently. + + This helper function performs matrix multiplication between complex arrays + represented by their real and imaginary parts, collapsing the last dimension. + + Parameters + ---------- + real : np.ndarray + Real part of the complex array + + imag : np.ndarray + Imaginary part of the complex array + + transpose_axes : tuple + Axes to transpose for matrix multiplication + + Returns + ------- + product : np.ndarray + Product of the array and its complex conjugate + + Notes + ----- + This function implements the formula: + product = (real × real.T + imag × imag.T) - i(real × imag.T - imag × real.T) + + Using einsum for efficient computation without explicitly creating complex arrays. + """ + + formula = 'jilm,jimk->jilk' + product = np.einsum(formula, real, real.transpose(transpose_axes)) + \ + np.einsum(formula, imag, imag.transpose(transpose_axes)) - 1j * \ + (np.einsum(formula, real, imag.transpose(transpose_axes)) - \ + np.einsum(formula, imag, real.transpose(transpose_axes))) + + return product + + +# helper function +def _multiply_conjugate_time(real: np.ndarray, imag: np.ndarray, transpose_axes: tuple) -> np.ndarray: + """ + Computes the product of a complex array and its conjugate without collapsing time dimension. + + Similar to _multiply_conjugate, but preserves the time dimension, which is + needed for certain connectivity metrics like wPLI. + + Parameters + ---------- + real : np.ndarray + Real part of the complex array + + imag : np.ndarray + Imaginary part of the complex array + + transpose_axes : tuple + Axes to transpose for matrix multiplication + + Returns + ------- + product : np.ndarray + Product of the array and its complex conjugate with time dimension preserved + + Notes + ----- + This function uses a different einsum formula than _multiply_conjugate: + 'jilm,jimk->jilkm' instead of 'jilm,jimk->jilk' + + This preserves the time dimension (m) in the output, which is necessary for + computing metrics that require individual time point values rather than + time-averaged products. + """ + formula = 'jilm,jimk->jilkm' + product = np.einsum(formula, real, real.transpose(transpose_axes)) + \ + np.einsum(formula, imag, imag.transpose(transpose_axes)) - 1j * \ + (np.einsum(formula, real, imag.transpose(transpose_axes)) - \ + np.einsum(formula, imag, real.transpose(transpose_axes))) + + return product + + +# helper function +def _multiply_product(real: np.ndarray, imag: np.ndarray, transpose_axes: tuple) -> np.ndarray: + """ + Computes the product of two complex arrays (not conjugate) efficiently. + + This helper function performs matrix multiplication between complex arrays + represented by their real and imaginary parts, collapsing the last dimension. + Unlike _multiply_conjugate, this computes z1 * z2 instead of z1 * conj(z2). 
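+
+    Worked identity, writing z1 = a + ib and z2 = c + id:
+    z1 * z2 = (ac - bd) + i(ad + bc), which is exactly the einsum expression
+    assembled below.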
+
+    Parameters
+    ----------
+    real : np.ndarray
+        Real part of the complex array
+
+    imag : np.ndarray
+        Imaginary part of the complex array
+
+    transpose_axes : tuple
+        Axes to transpose for matrix multiplication
+
+    Returns
+    -------
+    product : np.ndarray
+        Product of the array with itself (non-conjugate)
+
+    Notes
+    -----
+    This function implements the formula for z1 * z2:
+    product = (real × real.T - imag × imag.T) + i(real × imag.T + imag × real.T)
+
+    Using einsum for efficient computation without explicitly creating complex arrays.
+    This is used in the adjusted circular correlation (accorr) metric.
+    """
+    formula = 'jilm,jimk->jilk'
+    product = np.einsum(formula, real, real.transpose(transpose_axes)) - \
+              np.einsum(formula, imag, imag.transpose(transpose_axes)) + 1j * \
+              (np.einsum(formula, real, imag.transpose(transpose_axes)) + \
+              np.einsum(formula, imag, real.transpose(transpose_axes)))
+
+    return product
diff --git a/pyproject.toml b/pyproject.toml
index 4b83e58..4cb4114 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,8 +48,21 @@ pyxdf = ">=1.17.0"
 urllib3 = ">=2.5.0"
 requests = ">=2.32.4"
 pillow = ">=11.3.0"
+mne-icalabel = "^0.8.1"
 snirf = ">=0.8.0"
 
+[tool.poetry.group.optim_torch]
+optional = true
+
+[tool.poetry.group.optim_torch.dependencies]
+torch = "^2.10.0"
+
+[tool.poetry.group.optim_numba]
+optional = true
+
+[tool.poetry.group.optim_numba.dependencies]
+numba = "^0.63.1"
+
 [tool.poetry.group.dev.dependencies]
 mkdocs = ">=1.3.0"
 mkdocs-material = ">=8.2.15"
@@ -70,6 +83,13 @@ tabulate = ">=0.9.0"
 notebook = ">=7.4.3"
 pyqt6 = ">=6.10.2"
 
+[tool.poetry.group.dev_benchmark]
+optional = true
+
+[tool.poetry.group.dev_benchmark.dependencies]
+doit = "^0.36.0"
+dfply = "^0.3.3"
+
 [tool.poetry.group.shiny.dependencies]
 shiny = ">=1.1.0"
diff --git a/tests/benchmark/accorr/dodo.py b/tests/benchmark/accorr/dodo.py
new file mode 100644
index 0000000..919be85
--- /dev/null
+++ b/tests/benchmark/accorr/dodo.py
@@ -0,0 +1,272 @@
+import mne
+import time
+import numpy as np
+import pickle
+import json
+import pandas as pd
+import dfply as df
+import seaborn as sns
+from pathlib import Path
+from collections import OrderedDict
+from hypyp import analyses
+from tests.hypyp.sync.accorr import accorr_reference
+from hypyp.sync.accorr import accorr
+import numba
+
+"""
+Benchmark script comparing different optimization approaches for the Adjusted Circular Correlation (accorr) metric.
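+
+Run the tasks with ``doit`` from this directory; the optional dev_benchmark
+dependency group provides the required tooling (``poetry install --with dev_benchmark``).
+Results are written as JSON and pickle files under ``results/``; the
+``summary_plots`` task renders speed-up charts and ``bad_perfs`` lists variants
+that diverge from the reference results.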
+""" + +# Define frequency bands as a dictionary +freq_bands = { + 'Alpha-Low': [7.5, 11], + 'Alpha-High': [11.5, 13] +} + +# Convert to an OrderedDict to keep the defined order +freq_bands = OrderedDict(freq_bands) +print('Frequency bands:', freq_bands) + +preproc_S1 = mne.read_epochs('../../data/preproc_S1.fif') +preproc_S2 = mne.read_epochs('../../data/preproc_S2.fif') + +sampling_rate = preproc_S1.info['sfreq'] + +scaling_results = {method: [] for method in ['original', 'numba']} +scaling_times = {method: [] for method in ['original', 'numba']} + +def numba_run(cpu_num): + def f(complex_signal: np.ndarray, + epochs_average: bool = True, + show_progress: bool = True): + numba.set_num_threads(cpu_num) + return accorr(complex_signal, epochs_average, show_progress, optimization='numba') + return f + + +def torch_run(device): + def f(complex_signal: np.ndarray, + epochs_average: bool = True, + show_progress: bool = True): + return accorr(complex_signal, epochs_average, show_progress, + optimization=f'torch_{device}') + return f + + +method_dict = { + 'original': accorr_reference, + 'precomputed': accorr, + 'numba4': numba_run(4), + 'numba8': numba_run(8), + 'torch_cpu': torch_run('cpu'), + 'torch_mps': torch_run('mps'), +} + +numba_palette = sns.light_palette('C2', 3) +torch_palette = sns.light_palette('C3', 4) +method_palette = OrderedDict({ + 'original': 'C0', + 'precomputed': 'C1', + 'numba4': numba_palette[1], + 'numba8': numba_palette[2], + 'torch_cpu': torch_palette[1], + 'torch_mps': torch_palette[2], +}) + +out_path = Path('results') + +def multiply_channels(epochs, i): + ch_names = [f'{ch}{x}' for ch in epochs.ch_names for x in range(i)] + n_info = mne.create_info(ch_names, epochs.info['sfreq']) + return mne.EpochsArray(np.concatenate([epochs.get_data()] * i, axis=1), info=n_info) + + +def benchmark(method, epoch_multiplier, channel_multiplier): + # Create scaled dataset by concatenating + expanded_preproc_S1 = multiply_channels(preproc_S1, channel_multiplier) + expanded_preproc_S2 = multiply_channels(preproc_S2, channel_multiplier) + + epochs_list_S1 = [expanded_preproc_S1.copy() for _ in range(epoch_multiplier)] + epochs_list_S2 = [expanded_preproc_S2.copy() for _ in range(epoch_multiplier)] + + preproc_S1_scaled = mne.concatenate_epochs(epochs_list_S1) + preproc_S2_scaled = mne.concatenate_epochs(epochs_list_S2) + + # Prepare data for connectivity analysis + data_inter = np.array([preproc_S1_scaled, preproc_S2_scaled]) + + complex_signal = analyses.compute_freq_bands( + data_inter, + sampling_rate, + freq_bands, + filter_length=int(sampling_rate), + l_trans_bandwidth=5.0, + h_trans_bandwidth=5.0 + ) + n_epoch, n_ch, n_freq, n_samp = complex_signal.shape[1], complex_signal.shape[2], \ + complex_signal.shape[3], complex_signal.shape[4] + + # calculate all epochs at once, the only downside is that the disk may not have enough space + complex_signal = complex_signal.transpose((1, 3, 0, 2, 4)).reshape(n_epoch, n_freq, 2 * n_ch, n_samp) + + print(complex_signal.shape) + print(f"\n Testing {method}...") + try: + st = time.perf_counter() + result = method_dict[method](complex_signal, epochs_average=False, show_progress=True) + et = time.perf_counter() + perf_time = et - st + print(f" Time: {perf_time:.4f}s") + + result_path = out_path / f'original-e{epoch_multiplier}-c{channel_multiplier}_result.pkl' + + is_ok = None + max_diff = None + if method != 'original': + if result_path.is_file(): + with open(result_path, 'rb') as f: + orig_result = pickle.load(f) + + # Compare result with 
orig_result + max_diff = np.max(np.abs(result - orig_result)) + is_ok = np.allclose(result, orig_result, rtol=1e-9, atol=1e-10) + + return { + 'method': method, + 'epoch_multiplier': int(epoch_multiplier), + 'channel_multiplier': int(channel_multiplier), + 'time': float(perf_time), + 'result': result, + 'is_ok': is_ok, + 'max_diff': float(max_diff) if max_diff is not None else None + } + + except Exception as e: + raise e + + +benchmark_configs = pd.DataFrame( + [ + [1, 1], + [2, 1], + [3, 1], + [5, 1], + [8, 1], + [10, 1], + [1, 2], + [1, 4], + [1, 8], + [3, 2], + [3, 4], + [3, 8], + ], + columns=['epoch_multiplier', 'channel_multiplier'] +) + +def task_calc_benchmarks(): + "Executes the different optimizations on different problem sizes." + def benchmark_action(method, epoch_multiplier, channels_multiplier, targets): + res = benchmark(method, epoch_multiplier, channels_multiplier) + + if len(targets) > 1: + with open(targets[1], 'wb') as f: + pickle.dump(res['result'], f) + + del res['result'] + with open(targets[0], 'w') as f: + print(res) + json.dump(res, f) + + if not out_path.is_dir(): + out_path.mkdir() + + for _, r in benchmark_configs.iterrows(): + for method in method_dict: + name = f'{method}-e{r["epoch_multiplier"]}-c{r["channel_multiplier"]}' + json_out = out_path / (name + '.json') + result_out = out_path / (name + '_result.pkl') + yield { + 'name': name, + 'actions': [(benchmark_action, (method, r['epoch_multiplier'], r['channel_multiplier']))], + 'targets': [json_out] + ([result_out] * (method == 'original')), + 'uptodate': [json_out.is_file()], + 'file_dep': ([str(result_out).replace(method, 'original')] * int(method != 'original')) + } + + +def get_perfs(): + res = [] + for p in out_path.glob('*.json'): + with open(p) as f: + j = json.load(f) + res.append(j) + + base_ch_num = len(preproc_S1.info.ch_names) + len(preproc_S2.info.ch_names) + base_epoch_num = preproc_S1.get_data().shape[0] + perfs = ( + pd.DataFrame.from_records(res) + >> df.mutate( + channels = df.X.channel_multiplier * base_ch_num, + epochs = df.X.epoch_multiplier * base_epoch_num, + ) + ) + return perfs + + +def task_summary_plots(): + "Generates summary plot of the speed-up of different optimizations" + def action(targets): + perfs = get_perfs() + + # Calculate speedup relative to original method + speedup_data = [] + for (ch, ep), group in perfs.groupby(['channels', 'epochs']): + original_time = group[group['method'] == 'original']['time'].values + if len(original_time) > 0: + original_time = original_time[0] + for _, row in group.iterrows(): + speedup = original_time / row['time'] + speedup_data.append({ + 'method': row['method'], + 'channels': ch, + 'epochs': ep, + 'speedup': speedup + }) + + speedup_df = pd.DataFrame(speedup_data) + + fg = sns.catplot( + speedup_df, + x='channels', + y='speedup', + col='epochs', + col_wrap=3, + hue='method', + hue_order=method_palette.keys(), + palette=method_palette, + kind='bar', + sharey=False + ) + fg.savefig(targets[0]) + + return { + 'actions': [action], + 'targets': [out_path / 'benchmark.pdf'], + 'uptodate': [False] + } + + +def task_bad_perfs(): + "Outputs a csv listing optimizations that do not equal the original results within tolerance." 
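+    # Rows for the 'original' method carry is_ok=None (there is nothing to
+    # compare them against), so the mask below keeps only rows that explicitly
+    # failed the np.allclose comparison (is_ok == False).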
+ def action(targets): + bad_perfs = ( + get_perfs() + >> df.mask(df.X.is_ok == False) + ) + bad_perfs.to_csv(targets[0]) + + return { + 'actions': [action], + 'targets': [out_path / 'bad_perfs.csv'], + 'uptodate': [False] + } diff --git a/tests/conftest.py b/tests/conftest.py index 50eeb25..dae5289 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ import pytest import os -from collections import namedtuple +from collections import namedtuple, OrderedDict +import numpy as np import mne -from hypyp import utils +from hypyp import utils, analyses @pytest.fixture(scope="module") @@ -20,3 +21,60 @@ def epochs(): epochsTuple = namedtuple('epochs', ['epo1', 'epo2', 'epoch_merge']) return epochsTuple(epo1=epo1, epo2=epo2, epoch_merge=epoch_merge) + + +@pytest.fixture(scope="module") +def preprocessed_epochs(): + """ + Loading preprocessed test data for accorr optimization tests + """ + test_dir = os.path.dirname(__file__) + data_dir = os.path.join(test_dir, "data") + + epo1 = mne.read_epochs(os.path.join(data_dir, "preproc_S1.fif"), preload=True) + epo2 = mne.read_epochs(os.path.join(data_dir, "preproc_S2.fif"), preload=True) + + mne.epochs.equalize_epoch_counts([epo1, epo2]) + epoch_merge = utils.merge(epo1, epo2) + + preprocessedTuple= namedtuple('preprocessed_epochs', ['pepo1', 'pepo2', 'pepochs_merge']) + return preprocessedTuple(pepo1=epo1, pepo2=epo2, pepochs_merge=epoch_merge) + + +@pytest.fixture(scope="module") +def complex_signal(preprocessed_epochs): + """ + Compute complex analytic signals for accorr testing. + + Returns complex_signal with shape (n_epochs, n_freq, 2*n_channels, n_times) + ready for accorr computation. + """ + # Define frequency bands + freq_bands = OrderedDict({ + 'Alpha-Low': [7.5, 11], + 'Alpha-High': [11.5, 13] + }) + + # Stack participant data + data_inter = np.array([preprocessed_epochs.pepo1.get_data(), + preprocessed_epochs.pepo2.get_data()]) + sampling_rate = preprocessed_epochs.pepo1.info['sfreq'] + + # Compute frequency bands + # Returns shape: (n_participants=2, n_epochs, n_freq, n_ch, n_times) + complex_signal_raw = analyses.compute_freq_bands( + data_inter, + sampling_rate, + freq_bands, + filter_length=int(sampling_rate), + l_trans_bandwidth=5.0, + h_trans_bandwidth=5.0 + ) + + n_epoch, n_ch, n_freq, n_samp = complex_signal_raw.shape[1], complex_signal_raw.shape[2], \ + complex_signal_raw.shape[3], complex_signal_raw.shape[4] + + # calculate all epochs at once, the only downside is that the disk may not have enough space + complex_signal_reshaped = complex_signal_raw.transpose((1, 3, 0, 2, 4)).reshape(n_epoch, n_freq, 2 * n_ch, n_samp) + + return complex_signal_reshaped diff --git a/tests/data/preproc_S1.fif b/tests/data/preproc_S1.fif new file mode 100644 index 0000000..852a0ae Binary files /dev/null and b/tests/data/preproc_S1.fif differ diff --git a/tests/data/preproc_S2.fif b/tests/data/preproc_S2.fif new file mode 100644 index 0000000..f3e69a3 Binary files /dev/null and b/tests/data/preproc_S2.fif differ diff --git a/tests/hypyp/sync/accorr.py b/tests/hypyp/sync/accorr.py new file mode 100644 index 0000000..a5e7207 --- /dev/null +++ b/tests/hypyp/sync/accorr.py @@ -0,0 +1,135 @@ +""" +Reference implementations for testing optimized versions. + +This module contains the original, unoptimized implementations that serve as +ground truth for validating optimized versions. All optimized implementations +must produce numerically identical results to these reference versions. 
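+
+The benchmark tasks in tests/benchmark/accorr/dodo.py compare each optimized
+variant against accorr_reference using np.allclose(rtol=1e-9, atol=1e-10).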
+""" + +import numpy as np +from tqdm import tqdm +from hypyp.analyses import _multiply_conjugate, _multiply_product + + +def accorr_reference( + complex_signal: np.ndarray, + epochs_average: bool = True, + show_progress: bool = False, +) -> np.ndarray: + """ + Reference implementation of Adjusted Circular Correlation (unoptimized). + + This is the original implementation using nested loops for the denominator + calculation. It serves as the ground truth for testing optimized versions. + + All optimized implementations in hypyp.sync.accorr must produce results + that match this reference within numerical precision (typically < 1e-9). + + Parameters + ---------- + complex_signal : np.ndarray + Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times) + Note: This is the already reshaped signal from compute_sync. + + epochs_average : bool, optional + If True, connectivity values are averaged across epochs (default) + If False, epoch-by-epoch connectivity is preserved + + show_progress : bool, optional + If True, display a progress bar during computation (default: False for tests) + If False, no progress bar is shown + + Returns + ------- + con : np.ndarray + Adjusted circular correlation matrix with shape: + - If epochs_average=True: (n_freq, 2*n_channels, 2*n_channels) + - If epochs_average=False: (n_freq, n_epochs, 2*n_channels, 2*n_channels) + + Notes + ----- + The adjusted circular correlation is computed as: + + 1. Numerator (vectorized): Uses the difference between the absolute values of + the conjugate product and the direct product of normalized complex signals. + + 2. Denominator (loop): For each channel pair, computes optimal phase centering + parameters (m_adj, n_adj) that minimize the denominator, then calculates + the normalization factor. + + This metric provides a more accurate measure of circular correlation by + adjusting the phase centering for each channel pair individually, rather than + using a global circular mean. + + References + ---------- + Zimmermann, M., Schultz-Nielsen, K., Dumas, G., & Konvalinka, I. (2024). + Arbitrary methodological decisions skew inter-brain synchronization estimates + in hyperscanning-EEG studies. Imaging Neuroscience, 2. 
+    https://doi.org/10.1162/imag_a_00350
+    """
+    n_epochs = complex_signal.shape[0]
+    n_freq = complex_signal.shape[1]
+    n_ch_total = complex_signal.shape[2]
+
+    transpose_axes = (0, 1, 3, 2)
+
+    # Numerator (vectorized)
+    z = complex_signal / np.abs(complex_signal)
+    c, s = np.real(z), np.imag(z)
+
+    cross_conj = _multiply_conjugate(c, s, transpose_axes=transpose_axes)
+    r_minus = np.abs(cross_conj)
+
+    cross_prod = _multiply_product(c, s, transpose_axes=transpose_axes)
+    r_plus = np.abs(cross_prod)
+
+    num = r_minus - r_plus
+
+    # Denominator (loop) - UNOPTIMIZED REFERENCE VERSION
+    angle = np.angle(complex_signal)
+    den = np.zeros((n_epochs, n_freq, n_ch_total, n_ch_total))
+
+    total_pairs = (n_ch_total * (n_ch_total + 1)) // 2
+    pbar = tqdm(
+        total=total_pairs,
+        desc="  accorr_reference (denominator)",
+        disable=not show_progress,
+        leave=False,
+    )
+
+    for i in range(n_ch_total):
+        for j in range(i, n_ch_total):
+            alpha1 = angle[:, :, i, :]
+            alpha2 = angle[:, :, j, :]
+
+            phase_diff = alpha1 - alpha2
+            phase_sum = alpha1 + alpha2
+
+            mean_diff = np.angle(np.mean(np.exp(1j * phase_diff), axis=2, keepdims=True))
+            mean_sum = np.angle(np.mean(np.exp(1j * phase_sum), axis=2, keepdims=True))
+
+            n_adj = -1 * (mean_diff - mean_sum) / 2
+            m_adj = mean_diff + n_adj
+
+            x_sin = np.sin(alpha1 - m_adj)
+            y_sin = np.sin(alpha2 - n_adj)
+
+            den_ij = 2 * np.sqrt(
+                np.sum(x_sin**2, axis=2) * np.sum(y_sin**2, axis=2)
+            )
+            den[:, :, i, j] = den_ij
+            den[:, :, j, i] = den_ij
+
+            pbar.update(1)
+
+    pbar.close()
+
+    den = np.where(den == 0, 1, den)
+    con = num / den
+    con = con.swapaxes(0, 1)  # n_freq x n_epoch x 2*n_ch x 2*n_ch
+
+    if epochs_average:
+        con = np.nanmean(con, axis=1)
+
+    return con
diff --git a/tests/hypyp/sync/depricated_accorr_optimization.py b/tests/hypyp/sync/depricated_accorr_optimization.py
new file mode 100644
index 0000000..d993392
--- /dev/null
+++ b/tests/hypyp/sync/depricated_accorr_optimization.py
@@ -0,0 +1,837 @@
+import sys
+from multiprocessing import get_context
+from numba import njit, prange
+import numpy as np
+from numpy.typing import NDArray
+import torch
+from tqdm import tqdm
+from hypyp.sync.utils import _multiply_conjugate, _multiply_conjugate_time, _multiply_product
+
+def _accorr_hybrid(complex_signal: np.ndarray, epochs_average: bool = True,
+                   show_progress: bool = True) -> np.ndarray:
+    """
+    Computes Adjusted Circular Correlation using a hybrid approach.
+
+    This function calculates the adjusted circular correlation coefficient between
+    all channel pairs. It uses a vectorized computation for the numerator and an
+    exact loop-based computation for the denominator.
+
+    Parameters
+    ----------
+    complex_signal : np.ndarray
+        Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times)
+        Note: This is the already reshaped signal from compute_sync.
+
+    epochs_average : bool, optional
+        If True, connectivity values are averaged across epochs (default)
+        If False, epoch-by-epoch connectivity is preserved
+
+    show_progress : bool, optional
+        If True, display a progress bar during computation (default)
+        If False, no progress bar is shown
+
+    Returns
+    -------
+    con : np.ndarray
+        Adjusted circular correlation matrix with shape:
+        - If epochs_average=True: (n_freq, 2*n_channels, 2*n_channels)
+        - If epochs_average=False: (n_freq, n_epochs, 2*n_channels, 2*n_channels)
+
+    Notes
+    -----
+    The adjusted circular correlation is computed as:
+
+    1. Numerator (vectorized): Uses the difference between the absolute values of
+       the conjugate product and the direct product of normalized complex signals.
+
+    2. 
Denominator (loop): For each channel pair, computes optimal phase centering + parameters (m_adj, n_adj) that minimize the denominator, then calculates + the normalization factor. + + This metric provides a more accurate measure of circular correlation by + adjusting the phase centering for each channel pair individually, rather than + using a global circular mean. + + References + ---------- + Zimmermann, M., Schultz-Nielsen, K., Dumas, G., & Konvalinka, I. (2024). + Arbitrary methodological decisions skew inter-brain synchronization estimates + in hyperscanning-EEG studies. Imaging Neuroscience, 2. + https://doi.org/10.1162/imag_a_00350 + """ + n_epochs = complex_signal.shape[0] + n_freq = complex_signal.shape[1] + n_ch_total = complex_signal.shape[2] + + transpose_axes = (0, 1, 3, 2) + + # Numerator (vectorized) + z = complex_signal / np.abs(complex_signal) + c, s = np.real(z), np.imag(z) + + cross_conj = _multiply_conjugate(c, s, transpose_axes=transpose_axes) + r_minus = np.abs(cross_conj) + + cross_prod = _multiply_product(c, s, transpose_axes=transpose_axes) + r_plus = np.abs(cross_prod) + + num = r_minus - r_plus + + # Denominator (loop) + angle = np.angle(complex_signal) + den = np.zeros((n_epochs, n_freq, n_ch_total, n_ch_total)) + + total_pairs = (n_ch_total * (n_ch_total + 1)) // 2 + pbar = tqdm(total=total_pairs, desc=" accorr (denominator)", + disable=not show_progress, leave=False) + + for i in range(n_ch_total): + for j in range(i, n_ch_total): + alpha1 = angle[:, :, i, :] + alpha2 = angle[:, :, j, :] + + phase_diff = alpha1 - alpha2 + phase_sum = alpha1 + alpha2 + + mean_diff = np.angle(np.mean(np.exp(1j * phase_diff), axis=2, keepdims=True)) + mean_sum = np.angle(np.mean(np.exp(1j * phase_sum), axis=2, keepdims=True)) + + n_adj = -1 * (mean_diff - mean_sum) / 2 + m_adj = mean_diff + n_adj + + x_sin = np.sin(alpha1 - m_adj) + y_sin = np.sin(alpha2 - n_adj) + + den_ij = 2 * np.sqrt(np.sum(x_sin**2, axis=2) * np.sum(y_sin**2, axis=2)) + den[:, :, i, j] = den_ij + den[:, :, j, i] = den_ij + + pbar.update(1) + + pbar.close() + + den = np.where(den == 0, 1, den) + con = num / den + con = con.swapaxes(0, 1) # n_freq x n_epoch x 2*n_ch x 2*n_ch + + if epochs_average: + con = np.nanmean(con, axis=1) + + return con + + +@njit(nopython=False, parallel=True, cache=True) +def _accorr_den_calc(n_epochs, n_freq, n_ch_total, angle): + + den = np.zeros((n_epochs, n_freq, n_ch_total, n_ch_total)) + for i in prange(den.shape[2]): + for j in prange(i, den.shape[3]): + alpha1 = angle[:, :, i, :] + alpha2 = angle[:, :, j, :] + + phase_diff = alpha1 - alpha2 + phase_sum = alpha1 + alpha2 + + def axis2_mean(m): + return np.array([ + m[i, j, :].mean() + for i in prange(m.shape[0]) + for j in prange(m.shape[1]) + ]).reshape((m.shape[0], m.shape[1])) + + mean_diff = np.angle(axis2_mean(np.exp(1j * phase_diff))) + mean_sum = np.angle(axis2_mean(np.exp(1j * phase_sum))) + + n_adj = -1 * (mean_diff - mean_sum) / 2 + m_adj = mean_diff + n_adj + + x = alpha1.copy() + for xi in prange(x.shape[0]): + for xj in prange(x.shape[1]): + for xk in prange(x.shape[2]): + x[xi, xj, xk] -= m_adj[xi, xj] + x_sin = np.sin(x) + + y = alpha2.copy() + for yi in prange(y.shape[0]): + for yj in prange(y.shape[1]): + for yk in prange(y.shape[2]): + y[yi, yj, yk] -= n_adj[yi, yj] + y_sin = np.sin(y) + + den_ij = 2 * np.sqrt(np.sum(x_sin**2, axis=2) * np.sum(y_sin**2, axis=2)) + den[:, :, i, j] = den_ij + den[:, :, j, i] = den_ij + + return den + +def _accorr_hybrid_numba(complex_signal: np.ndarray, epochs_average: bool = 
True, + show_progress: bool = True) -> np.ndarray: + """ + Computes Adjusted Circular Correlation using a hybrid approach. + + This function calculates the adjusted circular correlation coefficient between + all channel pairs. It uses a vectorized computation for the numerator and an + exact loop-based computation for the denominator. + + Parameters + ---------- + complex_signal : np.ndarray + Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times) + Note: This is the already reshaped signal from compute_sync. + + epochs_average : bool, optional + If True, connectivity values are averaged across epochs (default) + If False, epoch-by-epoch connectivity is preserved + + show_progress : bool, optional + If True, display a progress bar during computation (default) + If False, no progress bar is shown + + Returns + ------- + con : np.ndarray + Adjusted circular correlation matrix with shape: + - If epochs_average=True: (n_freq, 2*n_channels, 2*n_channels) + - If epochs_average=False: (n_freq, n_epochs, 2*n_channels, 2*n_channels) + + Notes + ----- + The adjusted circular correlation is computed as: + + 1. Numerator (vectorized): Uses the difference between the absolute values of + the conjugate product and the direct product of normalized complex signals. + + 2. Denominator (loop): For each channel pair, computes optimal phase centering + parameters (m_adj, n_adj) that minimize the denominator, then calculates + the normalization factor. + + This metric provides a more accurate measure of circular correlation by + adjusting the phase centering for each channel pair individually, rather than + using a global circular mean. + + References + ---------- + Zimmermann, M., Schultz-Nielsen, K., Dumas, G., & Konvalinka, I. (2024). + Arbitrary methodological decisions skew inter-brain synchronization estimates + in hyperscanning-EEG studies. Imaging Neuroscience, 2. + https://doi.org/10.1162/imag_a_00350 + """ + n_epochs = complex_signal.shape[0] + n_freq = complex_signal.shape[1] + n_ch_total = complex_signal.shape[2] + + transpose_axes = (0, 1, 3, 2) + + # Numerator (vectorized) + z = complex_signal / np.abs(complex_signal) + c, s = np.real(z), np.imag(z) + + cross_conj = _multiply_conjugate(c, s, transpose_axes=transpose_axes) + r_minus = np.abs(cross_conj) + + cross_prod = _multiply_product(c, s, transpose_axes=transpose_axes) + r_plus = np.abs(cross_prod) + + num = r_minus - r_plus + + # Denominator (loop) + angle = np.angle(complex_signal) + + den = _accorr_den_calc(n_epochs, n_freq, n_ch_total, angle) + + den = np.where(den == 0, 1, den) + con = num / den + con = con.swapaxes(0, 1) # n_freq x n_epoch x 2*n_ch x 2*n_ch + + if epochs_average: + con = np.nanmean(con, axis=1) + + return con + + +def _accorr_hybrid_vectorized(complex_signal: np.ndarray, epochs_average: bool = True, + show_progress: bool = True) -> np.ndarray: + """ + Computes Adjusted Circular Correlation using full vectorization. + + This function calculates the adjusted circular correlation coefficient between + all channel pairs using fully vectorized operations, eliminating nested loops + for significant performance improvements. + + Parameters + ---------- + complex_signal : np.ndarray + Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times) + Note: This is the already reshaped signal from compute_sync. 
+ + epochs_average : bool, optional + If True, connectivity values are averaged across epochs (default) + If False, epoch-by-epoch connectivity is preserved + + show_progress : bool, optional + If True, display a progress bar during computation (default) + If False, no progress bar is shown + + Returns + ------- + con : np.ndarray + Adjusted circular correlation matrix with shape: + - If epochs_average=True: (n_freq, 2*n_channels, 2*n_channels) + - If epochs_average=False: (n_freq, n_epochs, 2*n_channels, 2*n_channels) + + Notes + ----- + The adjusted circular correlation is computed using full vectorization: + + 1. Numerator (vectorized): Uses the difference between the absolute values of + the conjugate product and the direct product of normalized complex signals. + + 2. Denominator (vectorized): Computes optimal phase centering parameters + (m_adj, n_adj) for all channel pairs simultaneously using broadcasting, + then calculates the normalization factor without loops. + + This fully vectorized approach provides significant speedup compared to the + loop-based implementation while maintaining numerical accuracy. + + References + ---------- + Zimmermann, M., Schultz-Nielsen, K., Dumas, G., & Konvalinka, I. (2024). + Arbitrary methodological decisions skew inter-brain synchronization estimates + in hyperscanning-EEG studies. Imaging Neuroscience, 2. + https://doi.org/10.1162/imag_a_00350 + """ + n_epochs = complex_signal.shape[0] + n_freq = complex_signal.shape[1] + n_ch_total = complex_signal.shape[2] + n_times = complex_signal.shape[3] + + transpose_axes = (0, 1, 3, 2) + + # Numerator (vectorized) + z = complex_signal / np.abs(complex_signal) + c, s = np.real(z), np.imag(z) + + cross_conj = _multiply_conjugate(c, s, transpose_axes=transpose_axes) + r_minus = np.abs(cross_conj) + + cross_prod = _multiply_product(c, s, transpose_axes=transpose_axes) + r_plus = np.abs(cross_prod) + + num = r_minus - r_plus + + # Denominator (fully vectorized) + angle = np.angle(complex_signal) + + # Expand dimensions for broadcasting + # angle shape: (n_epochs, n_freq, n_ch_total, n_times) + alpha1_all = angle[:, :, :, None, :] # (n_epochs, n_freq, n_ch_total, 1, n_times) + alpha2_all = angle[:, :, None, :, :] # (n_epochs, n_freq, 1, n_ch_total, n_times) + + # Compute phase differences and sums for all pairs + phase_diff = alpha1_all - alpha2_all # (n_epochs, n_freq, n_ch_total, n_ch_total, n_times) + phase_sum = alpha1_all + alpha2_all # (n_epochs, n_freq, n_ch_total, n_ch_total, n_times) + + mean_diff = np.angle(np.mean(np.exp(1j * phase_diff), axis=4, keepdims=True)) + mean_sum = np.angle(np.mean(np.exp(1j * phase_sum), axis=4, keepdims=True)) + + # Compute optimal phase centering parameters + n_adj = -1 * (mean_diff - mean_sum) / 2 + m_adj = mean_diff + n_adj + + # Compute sine deviations for all pairs + x_sin = np.sin(alpha1_all - m_adj) # (n_epochs, n_freq, n_ch_total, n_ch_total, n_times) + y_sin = np.sin(alpha2_all - n_adj) # (n_epochs, n_freq, n_ch_total, n_ch_total, n_times) + + # Sum of squared sines + x_sin_sq_sum = np.sum(x_sin**2, axis=4) # (n_epochs, n_freq, n_ch_total, n_ch_total) + y_sin_sq_sum = np.sum(y_sin**2, axis=4) # (n_epochs, n_freq, n_ch_total, n_ch_total) + + # Compute denominator + den = 2 * np.sqrt(x_sin_sq_sum * y_sin_sq_sum) + + # Handle division by zero + den = np.where(den == 0, 1, den) + + # Compute connectivity + con = num / den + con = con.swapaxes(0, 1) # n_freq x n_epoch x 2*n_ch x 2*n_ch + + if epochs_average: + con = np.nanmean(con, axis=1) + + return con + + +def 
+
+
+def _compute_pair_denominator(args):
+    """Helper function for multiprocessing implementation."""
+    i, j, angle = args
+    n_epochs, n_freq = angle.shape[0], angle.shape[1]
+
+    alpha1 = angle[:, :, i, :]
+    alpha2 = angle[:, :, j, :]
+
+    phase_diff = alpha1 - alpha2
+    phase_sum = alpha1 + alpha2
+
+    mean_diff = np.angle(np.mean(np.exp(1j * phase_diff), axis=2, keepdims=True))
+    mean_sum = np.angle(np.mean(np.exp(1j * phase_sum), axis=2, keepdims=True))
+
+    n_adj = -1 * (mean_diff - mean_sum) / 2
+    m_adj = mean_diff + n_adj
+
+    x_sin = np.sin(alpha1 - m_adj)
+    y_sin = np.sin(alpha2 - n_adj)
+
+    den_ij = 2 * np.sqrt(np.sum(x_sin**2, axis=2) * np.sum(y_sin**2, axis=2))
+
+    return i, j, den_ij
+
+
+def _accorr_hybrid_multiprocessing(complex_signal: np.ndarray, epochs_average: bool = True,
+                                   show_progress: bool = True, n_jobs: int = 4) -> np.ndarray:
+    """
+    Computes Adjusted Circular Correlation using multiprocessing parallelization.
+
+    This implementation distributes channel pair computations across multiple
+    processes for improved performance on multi-core systems.
+
+    Parameters
+    ----------
+    complex_signal : np.ndarray
+        Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times)
+
+    epochs_average : bool, optional
+        If True, connectivity values are averaged across epochs (default)
+        If False, epoch-by-epoch connectivity is preserved
+
+    show_progress : bool, optional
+        If True, display a progress bar during computation (default)
+
+    n_jobs : int, optional
+        Number of parallel processes to use (default: 4)
+
+    Returns
+    -------
+    con : np.ndarray
+        Adjusted circular correlation matrix with shape:
+        - If epochs_average=True: (n_freq, 2*n_channels, 2*n_channels)
+        - If epochs_average=False: (n_freq, n_epochs, 2*n_channels, 2*n_channels)
+    """
+    n_epochs = complex_signal.shape[0]
+    n_freq = complex_signal.shape[1]
+    n_ch_total = complex_signal.shape[2]
+
+    transpose_axes = (0, 1, 3, 2)
+
+    # Numerator (vectorized)
+    z = complex_signal / np.abs(complex_signal)
+    c, s = np.real(z), np.imag(z)
+
+    cross_conj = _multiply_conjugate(c, s, transpose_axes=transpose_axes)
+    r_minus = np.abs(cross_conj)
+
+    cross_prod = _multiply_product(c, s, transpose_axes=transpose_axes)
+    r_plus = np.abs(cross_prod)
+
+    num = r_minus - r_plus
+
+    # Denominator (multiprocessing)
+    angle = np.angle(complex_signal)
+    den = np.zeros((n_epochs, n_freq, n_ch_total, n_ch_total))
+
+    # Prepare list of channel pairs
+    pairs = [(i, j, angle) for i in range(n_ch_total) for j in range(i, n_ch_total)]
+    total_pairs = len(pairs)
+
+    pbar = tqdm(total=total_pairs, desc=" accorr (denominator)",
+                disable=not show_progress, leave=False)
+
+    # Use the 'fork' start method on Unix systems (macOS, Linux) for better
+    # performance; Windows only supports 'spawn', which is selected explicitly.
+    ctx = get_context('fork' if sys.platform != 'win32' else 'spawn')
+
+    with ctx.Pool(processes=n_jobs) as pool:
+        for i, j, den_ij in pool.imap_unordered(_compute_pair_denominator, pairs):
+            den[:, :, i, j] = den_ij
+            den[:, :, j, i] = den_ij
+            pbar.update(1)
+
+    pbar.close()
+
+    den = np.where(den == 0, 1, den)
+    con = num / den
+    con = con.swapaxes(0, 1)
+
+    if epochs_average:
+        con = np.nanmean(con, axis=1)
+
+    return con
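+
+# Usage sketch (hypothetical call): under the 'spawn' start method the worker
+# processes re-import this module, so callers should guard their entry point.
+#
+#     if __name__ == '__main__':
+#         con = _accorr_hybrid_multiprocessing(complex_signal, n_jobs=8)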
+
+
+def _accorr_hybrid_precompute_torch(
+    complex_signal: NDArray[np.complexfloating],
+    epochs_average: bool = True,
+    show_progress: bool = True,
+    device: str = 'cpu'
+) -> NDArray[np.floating]:
+    """
+    PyTorch-optimized version of Adjusted Circular Correlation with precomputation.
+
+    Uses PyTorch for GPU acceleration and efficient tensor operations. This version
+    computes the entire denominator matrix without loops using advanced broadcasting.
+
+    Parameters
+    ----------
+    complex_signal : np.ndarray
+        Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times)
+
+    epochs_average : bool, optional
+        If True, connectivity values are averaged across epochs (default)
+        If False, epoch-by-epoch connectivity is preserved
+
+    show_progress : bool, optional
+        If True, display a progress bar during computation (default)
+        If False, no progress bar is shown
+
+    device : str, optional
+        Torch device to compute on: 'cpu' (default), 'cuda', or 'mps'.
+        A ValueError is raised if the requested device is unavailable.
+
+    Returns
+    -------
+    con : np.ndarray
+        Adjusted circular correlation matrix with shape:
+        - If epochs_average=True: (n_freq, 2*n_channels, 2*n_channels)
+        - If epochs_average=False: (n_freq, n_epochs, 2*n_channels, 2*n_channels)
+
+    Notes
+    -----
+    This implementation uses PyTorch to:
+    1. Leverage GPU acceleration if available
+    2. Eliminate all Python loops using advanced tensor broadcasting
+    3. Optimize memory access patterns for better cache utilization
+
+    The denominator computation is fully vectorized by expanding dimensions and
+    using broadcasting to compute all channel pairs simultaneously.
+    """
+    if device == 'cuda':
+        if not torch.cuda.is_available():
+            raise ValueError('CUDA is not available on this computer')
+    elif device == 'mps':
+        if not torch.backends.mps.is_available():
+            raise ValueError('MPS is not available on this computer')
+
+    complex_type = torch.complex64 if device == 'mps' else torch.complex128
+
+    # Convert to torch tensors (double precision to match numpy; MPS only
+    # supports single precision)
+    complex_tensor = torch.from_numpy(complex_signal).to(device=device, dtype=complex_type)
+
+    n_epochs, n_freq, n_ch_total, n_times = complex_tensor.shape
+
+    # Numerator (vectorized)
+    z = complex_tensor / torch.abs(complex_tensor)
+    c, s = z.real, z.imag
+
+    # Cross products using einsum, matching the numpy implementation formula
+    # 'jilm,jimk->jilk' (there j=epoch, i=freq, l/k=channels, m=time);
+    # here e=epoch, f=freq, i/j=channels, t=time
+    formula = 'efit,efjt->efij'
+
+    # _multiply_conjugate: (real × real.T + imag × imag.T) - i(real × imag.T - imag × real.T)
+    cross_conj = (torch.einsum(formula, c, c) + torch.einsum(formula, s, s)) - 1j * \
+                 (torch.einsum(formula, c, s) - torch.einsum(formula, s, c))
+
+    # _multiply_product: (real × real.T - imag × imag.T) + i(real × imag.T + imag × real.T)
+    cross_prod = (torch.einsum(formula, c, c) - torch.einsum(formula, s, s)) + 1j * \
+                 (torch.einsum(formula, c, s) + torch.einsum(formula, s, c))
+
+    r_minus = torch.abs(cross_conj)
+    r_plus = torch.abs(cross_prod)
+    num = r_minus - r_plus
+
+    # Pre-compute m_adj and n_adj for ALL pairs (dividing by n_times turns the
+    # sums into circular means without changing their angle)
+    mean_diff_all = torch.angle(cross_conj / n_times)
+    mean_sum_all = torch.angle(cross_prod / n_times)
+
+    n_adj_all = -0.5 * (mean_diff_all - mean_sum_all)
+    m_adj_all = mean_diff_all + n_adj_all
+
+    # Denominator - fully vectorized using broadcasting
+    angle = torch.angle(complex_tensor)
+
+    # For the denominator, we need to compute for each pair (i,j):
+    #   x_sin = sin(alpha_i - m_adj[i,j])
+    #   y_sin = sin(alpha_j - n_adj[i,j])
+    # where alpha_i has shape [e, f, t] for channel i
+
+    # Expand angle dimensions: [e, f, i, t] -> [e, f, i, 1, t] and [e, f, 1, j, t]
+    angle_i = angle.unsqueeze(3)  # [e, f, i, 1, t]
+    angle_j = angle.unsqueeze(2)  # [e, f, 1, j, t]
+
+    # Expand m_adj and n_adj: [e, f, i, j] -> [e, f, i, j, 1]
+    m_adj = m_adj_all.unsqueeze(-1)  # [e, f, i, j, 1]
+    n_adj = n_adj_all.unsqueeze(-1)  # [e, f, i, j, 1]
+
+    # Compute sin terms with proper broadcasting: for each pair (i,j),
+    # subtract m_adj[i,j] from angle[i,:] and n_adj[i,j] from angle[j,:]
+    x_sin = torch.sin(angle_i - m_adj)  # [e, f, i, j, t] - broadcasts [e,f,i,1,t] - [e,f,i,j,1]
+    y_sin = torch.sin(angle_j - n_adj)  # [e, f, i, j, t] - broadcasts [e,f,1,j,t] - [e,f,i,j,1]
+
+    # Sum over time dimension and compute denominator
+    x_sin_sq_sum = torch.sum(x_sin**2, dim=-1)  # [e, f, i, j]
+    y_sin_sq_sum = torch.sum(y_sin**2, dim=-1)  # [e, f, i, j]
+
+    den = 2 * torch.sqrt(x_sin_sq_sum * y_sin_sq_sum)
+
+    # Avoid division by zero
+    den = torch.where(den == 0, torch.ones_like(den), den)
+
+    # Compute connectivity
+    con = num / den
+    con = con.permute(1, 0, 2, 3)  # [n_freq, n_epochs, n_ch, n_ch]
+
+    if epochs_average:
+        con = torch.nanmean(con, dim=1)
+
+    # Convert back to numpy
+    return con.cpu().numpy()
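+
+# Device-selection sketch (assumes torch is installed; the availability checks
+# are standard PyTorch API):
+#
+#     device = 'cuda' if torch.cuda.is_available() else (
+#         'mps' if torch.backends.mps.is_available() else 'cpu')
+#     con = _accorr_hybrid_precompute_torch(complex_signal, device=device)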
+
+
+@njit(parallel=True, cache=True)
+def _accorr_den_calc(n_epochs, n_freq, n_ch_total, angle):
+
+    den = np.zeros((n_epochs, n_freq, n_ch_total, n_ch_total))
+    for i in prange(den.shape[2]):
+        # numba only parallelizes the outermost prange; inner loops use range
+        for j in range(i, den.shape[3]):
+            alpha1 = angle[:, :, i, :]
+            alpha2 = angle[:, :, j, :]
+
+            phase_diff = alpha1 - alpha2
+            phase_sum = alpha1 + alpha2
+
+            # Mean over the time axis, written as explicit loops so numba can
+            # compile it in nopython mode
+            def axis2_mean(m):
+                out = np.empty((m.shape[0], m.shape[1]), dtype=m.dtype)
+                for mi in range(m.shape[0]):
+                    for mj in range(m.shape[1]):
+                        out[mi, mj] = m[mi, mj, :].mean()
+                return out
+
+            mean_diff = np.angle(axis2_mean(np.exp(1j * phase_diff)))
+            mean_sum = np.angle(axis2_mean(np.exp(1j * phase_sum)))
+
+            n_adj = -1 * (mean_diff - mean_sum) / 2
+            m_adj = mean_diff + n_adj
+
+            x = alpha1.copy()
+            for xi in range(x.shape[0]):
+                for xj in range(x.shape[1]):
+                    for xk in range(x.shape[2]):
+                        x[xi, xj, xk] -= m_adj[xi, xj]
+            x_sin = np.sin(x)
+
+            y = alpha2.copy()
+            for yi in range(y.shape[0]):
+                for yj in range(y.shape[1]):
+                    for yk in range(y.shape[2]):
+                        y[yi, yj, yk] -= n_adj[yi, yj]
+            y_sin = np.sin(y)
+
+            den_ij = 2 * np.sqrt(np.sum(x_sin**2, axis=2) * np.sum(y_sin**2, axis=2))
+            den[:, :, i, j] = den_ij
+            den[:, :, j, i] = den_ij
+
+    return den
+
+
+def _accorr_hybrid_numba(complex_signal: np.ndarray, epochs_average: bool = True,
+                         show_progress: bool = True) -> np.ndarray:
+    """
+    Computes Adjusted Circular Correlation using a numba-accelerated hybrid approach.
+
+    This function calculates the adjusted circular correlation coefficient between
+    all channel pairs. It uses a vectorized computation for the numerator and an
+    exact loop-based computation for the denominator, compiled with numba.
+
+    Parameters
+    ----------
+    complex_signal : np.ndarray
+        Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times)
+        Note: This is the already reshaped signal from compute_sync.
+
+    epochs_average : bool, optional
+        If True, connectivity values are averaged across epochs (default)
+        If False, epoch-by-epoch connectivity is preserved
+
+    show_progress : bool, optional
+        If True, display a progress bar during computation (default)
+        If False, no progress bar is shown
+
+    Returns
+    -------
+    con : np.ndarray
+        Adjusted circular correlation matrix with shape:
+        - If epochs_average=True: (n_freq, 2*n_channels, 2*n_channels)
+        - If epochs_average=False: (n_freq, n_epochs, 2*n_channels, 2*n_channels)
+
+    Notes
+    -----
+    The adjusted circular correlation is computed as:
+
+    1. Numerator (vectorized): Uses the difference between the absolute values of
+       the conjugate product and the direct product of normalized complex signals.
+
+    2. Denominator (loop, numba-compiled): For each channel pair, computes optimal
+       phase centering parameters (m_adj, n_adj) that minimize the denominator,
+       then calculates the normalization factor.
+
+    This metric provides a more accurate measure of circular correlation by
+    adjusting the phase centering for each channel pair individually, rather than
+    using a global circular mean.
+
+    References
+    ----------
+    Zimmermann, M., Schultz-Nielsen, K., Dumas, G., & Konvalinka, I. (2024).
+    Arbitrary methodological decisions skew inter-brain synchronization estimates
+    in hyperscanning-EEG studies. Imaging Neuroscience, 2.
+    https://doi.org/10.1162/imag_a_00350
+    """
+    n_epochs = complex_signal.shape[0]
+    n_freq = complex_signal.shape[1]
+    n_ch_total = complex_signal.shape[2]
+
+    transpose_axes = (0, 1, 3, 2)
+
+    # Numerator (vectorized)
+    z = complex_signal / np.abs(complex_signal)
+    c, s = np.real(z), np.imag(z)
+
+    cross_conj = _multiply_conjugate(c, s, transpose_axes=transpose_axes)
+    r_minus = np.abs(cross_conj)
+
+    cross_prod = _multiply_product(c, s, transpose_axes=transpose_axes)
+    r_plus = np.abs(cross_prod)
+
+    num = r_minus - r_plus
+
+    # Denominator (loop)
+    angle = np.angle(complex_signal)
+
+    den = _accorr_den_calc(n_epochs, n_freq, n_ch_total, angle)
+
+    den = np.where(den == 0, 1, den)
+    con = num / den
+    con = con.swapaxes(0, 1)  # n_freq x n_epoch x 2*n_ch x 2*n_ch
+
+    if epochs_average:
+        con = np.nanmean(con, axis=1)
+
+    return con
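+
+# Consistency sketch (hypothetical check): every implementation in this module
+# is expected to agree with the others to floating-point tolerance, which is
+# what tests/test_sync.py verifies.
+#
+#     ref = _accorr_hybrid(complex_signal, epochs_average=False)
+#     fast = _accorr_hybrid_numba(complex_signal, epochs_average=False)
+#     assert np.allclose(ref, fast, rtol=1e-9, atol=1e-10)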
diff --git a/tests/test_sync.py b/tests/test_sync.py
new file mode 100644
index 0000000..e1dee01
--- /dev/null
+++ b/tests/test_sync.py
@@ -0,0 +1,109 @@
+"""
+Tests for synchronization metrics, particularly adjusted circular correlation (accorr).
+
+All optimized implementations are tested against the unoptimized reference
+implementation to ensure numerical correctness.
+"""
+
+import numpy as np
+import pytest
+
+from hypyp.sync.accorr import accorr, NUMBA_AVAILABLE, TORCH_AVAILABLE, MPS_AVAILABLE
+from tests.hypyp.sync.accorr import accorr_reference
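+
+# The `complex_signal` fixture used below is assumed to come from a
+# conftest.py that is not part of this diff; a minimal sketch (hypothetical
+# shapes) would be:
+#
+#     @pytest.fixture
+#     def complex_signal():
+#         rng = np.random.default_rng(42)
+#         shape = (3, 2, 8, 64)  # (n_epochs, n_freq, 2*n_channels, n_times)
+#         return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)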
+""" + +import numpy as np +import pytest + +from hypyp.sync.accorr import accorr, NUMBA_AVAILABLE, TORCH_AVAILABLE, MPS_AVAILABLE +from tests.hypyp.sync.accorr import accorr_reference + + +class TestAccorrReference: + """Basic properties of the reference implementation.""" + + def test_reference_shape_no_average(self, complex_signal): + result = accorr_reference(complex_signal, epochs_average=False, show_progress=False) + n_epochs, n_freq, n_ch, _ = complex_signal.shape + assert result.shape == (n_freq, n_epochs, n_ch, n_ch) + + def test_reference_shape_with_average(self, complex_signal): + result = accorr_reference(complex_signal, epochs_average=True, show_progress=False) + n_epochs, n_freq, n_ch, _ = complex_signal.shape + assert result.shape == (n_freq, n_ch, n_ch) + + def test_reference_value_range(self, complex_signal): + result = accorr_reference(complex_signal, epochs_average=True, show_progress=False) + assert np.all(result >= -1) and np.all(result <= 1) + assert not np.any(np.isnan(result)) + + def test_reference_symmetry(self, complex_signal): + result = accorr_reference(complex_signal, epochs_average=True, show_progress=False) + for freq_idx in range(result.shape[0]): + matrix = result[freq_idx] + np.testing.assert_allclose(matrix, matrix.T, rtol=1e-10, atol=1e-12) + + +class TestAccorrOptimizations: + """Optimized implementations must match reference.""" + + MPS_TOL = 1e-5 + + def test_default_vs_reference(self, complex_signal): + result_reference = accorr_reference(complex_signal, epochs_average=False, show_progress=False) + result_optimized = accorr(complex_signal, epochs_average=False, show_progress=False, optimization=None) + np.testing.assert_allclose(result_optimized, result_reference, rtol=1e-9, atol=1e-10) + + @pytest.mark.skipif(not NUMBA_AVAILABLE, reason="Numba not available") + def test_numba_vs_reference(self, complex_signal): + result_reference = accorr_reference(complex_signal, epochs_average=False, show_progress=False) + result_numba = accorr(complex_signal, epochs_average=False, show_progress=False, optimization="numba") + np.testing.assert_allclose(result_numba, result_reference, rtol=1e-9, atol=1e-10) + + @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Torch not available") + def test_torch_cpu_vs_reference(self, complex_signal): + result_reference = accorr_reference(complex_signal, epochs_average=False, show_progress=False) + result_torch = accorr(complex_signal, epochs_average=False, show_progress=False, optimization="torch_cpu") + np.testing.assert_allclose(result_torch, result_reference, rtol=1e-9, atol=1e-10) + + @pytest.mark.skipif(not TORCH_AVAILABLE or not MPS_AVAILABLE, reason="Torch or MPS not available") + def test_torch_mps_vs_reference(self, complex_signal): + result_reference = accorr_reference(complex_signal, epochs_average=False, show_progress=False) + result_torch_mps = accorr(complex_signal, epochs_average=False, show_progress=False, optimization="torch_mps") + np.testing.assert_allclose(result_torch_mps, result_reference, rtol=self.MPS_TOL, atol=self.MPS_TOL) + + +class TestAccorrFeatures: + """Specific feature behavior.""" + + def test_epochs_averaging(self, complex_signal): + result_avg = accorr(complex_signal, epochs_average=True, show_progress=False, optimization=None) + result_no_avg = accorr(complex_signal, epochs_average=False, show_progress=False, optimization=None) + + n_freq = complex_signal.shape[1] + n_ch = complex_signal.shape[2] + assert result_avg.shape == (n_freq, n_ch, n_ch) + + manual_avg = np.nanmean(result_no_avg, 
+
+
+class TestAccorrFeatures:
+    """Specific feature behavior."""
+
+    def test_epochs_averaging(self, complex_signal):
+        result_avg = accorr(complex_signal, epochs_average=True, show_progress=False, optimization=None)
+        result_no_avg = accorr(complex_signal, epochs_average=False, show_progress=False, optimization=None)
+
+        n_freq = complex_signal.shape[1]
+        n_ch = complex_signal.shape[2]
+        assert result_avg.shape == (n_freq, n_ch, n_ch)
+
+        manual_avg = np.nanmean(result_no_avg, axis=1)
+        np.testing.assert_allclose(result_avg, manual_avg, rtol=1e-10, atol=1e-12)
+
+    def test_epochs_averaging_matches_reference(self, complex_signal):
+        result_ref = accorr_reference(complex_signal, epochs_average=True, show_progress=False)
+        result_opt = accorr(complex_signal, epochs_average=True, show_progress=False, optimization=None)
+        np.testing.assert_allclose(result_opt, result_ref, rtol=1e-9, atol=1e-10)
+
+
+class TestAccorrErrorHandling:
+    """Error handling."""
+
+    def test_invalid_optimization(self, complex_signal):
+        with pytest.raises(ValueError, match="Optimization parameter is none of the accepted"):
+            accorr(complex_signal, epochs_average=False, show_progress=False, optimization="invalid_option")
+
+    @pytest.mark.skipif(NUMBA_AVAILABLE, reason="Test requires numba to be unavailable")
+    def test_numba_unavailable(self, complex_signal):
+        with pytest.raises(ValueError, match="Numba library not available"):
+            accorr(complex_signal, epochs_average=False, show_progress=False, optimization="numba")
+
+    @pytest.mark.skipif(TORCH_AVAILABLE, reason="Test requires torch to be unavailable")
+    def test_torch_unavailable(self, complex_signal):
+        with pytest.raises(ValueError, match="Torch library not available"):
+            accorr(complex_signal, epochs_average=False, show_progress=False, optimization="torch_cpu")
+
+    @pytest.mark.skipif(not TORCH_AVAILABLE or MPS_AVAILABLE, reason="Test requires torch with MPS unavailable")
+    def test_mps_unavailable(self, complex_signal):
+        with pytest.raises(ValueError, match="MPS not available"):
+            accorr(complex_signal, epochs_average=False, show_progress=False, optimization="torch_mps")
\ No newline at end of file
diff --git a/tutorial/accorr_test.py b/tutorial/accorr_test.py
new file mode 100644
index 0000000..8530bbf
--- /dev/null
+++ b/tutorial/accorr_test.py
@@ -0,0 +1,292 @@
+import mne
+import time
+import numpy as np
+import pickle
+import json
+import pandas as pd
+import dfply as df
+import seaborn as sns
+from pathlib import Path
+from collections import OrderedDict
+from hypyp import analyses
+from hypyp.sync.accorr import (
+    _accorr_hybrid,
+    _accorr_hybrid_vectorized,
+    _accorr_hybrid_precompute,
+    _accorr_hybrid_precompute_numba,
+    _accorr_hybrid_numba,
+    _accorr_hybrid_precompute_torch,
+    _accorr_hybrid_precompute_torch_loop,
+    NUMBA_AVAILABLE
+)
+import numba
+
+"""
+Benchmark script comparing different optimization approaches for the Adjusted
+Circular Correlation (accorr) metric.
+
+This script benchmarks:
+1. Original: Loop-based approach with tqdm progress bar
+2. Numba: JIT compilation with parallelization, at 4 and 8 threads (if numba is installed)
+3. Precompute: Variants that precompute the per-pair phase-centering parameters
+4. Torch: PyTorch implementations (precompute and loop) on CPU and MPS devices
+
+Performance characteristics (including variants benchmarked in earlier runs):
+- Original (loop-based): Good baseline, low memory overhead, CPU-friendly
+- Vectorized: High memory usage, slower due to large intermediate arrays
+- Numba: Requires compilation on first run, can be fast with parallelization
+- Multiprocessing: IPC overhead makes it slower for small datasets
+"""
+
+# Define frequency bands as a dictionary
+freq_bands = {
+    'Alpha-Low': [7.5, 11],
+    'Alpha-High': [11.5, 13]
+}
+
+# Convert to an OrderedDict to keep the defined order
+freq_bands = OrderedDict(freq_bands)
+print('Frequency bands:', freq_bands)
+
+preproc_S1 = mne.read_epochs('preproc_S1.fif')
+preproc_S2 = mne.read_epochs('preproc_S2.fif')
+
+sampling_rate = preproc_S1.info['sfreq']
+
+scaling_results = {method: [] for method in ['original', 'numba']}
+scaling_times = {method: [] for method in ['original', 'numba']}
+
+def numba_run(cpu_num, numba_f):
+    def f(complex_signal: np.ndarray,
+          epochs_average: bool = True,
+          show_progress: bool = True):
+        numba.set_num_threads(cpu_num)
+        return numba_f(complex_signal, epochs_average, show_progress)
+    return f
+
+
+def torch_run(device, torch_f):
+    def f(complex_signal: np.ndarray,
+          epochs_average: bool = True,
+          show_progress: bool = True):
+        return torch_f(complex_signal, epochs_average,
+                       show_progress, device=device)
+    return f
+
+
+method_dict = {
+    'original': _accorr_hybrid,
+    'numba4': numba_run(4, _accorr_hybrid_numba),
+    'numba8': numba_run(8, _accorr_hybrid_numba),
+    'numba_precompute8': numba_run(8, _accorr_hybrid_precompute_numba),
+    'precompute': _accorr_hybrid_precompute,
+    'torch_precompute_mps': torch_run('mps', _accorr_hybrid_precompute_torch),
+    'torch_precompute_cpu': torch_run('cpu', _accorr_hybrid_precompute_torch),
+    'torch_loop_mps': torch_run('mps', _accorr_hybrid_precompute_torch_loop),
+    'torch_loop_cpu': torch_run('cpu', _accorr_hybrid_precompute_torch_loop),
+}
+
+numba_palette = sns.light_palette('C1', 3)
+numba_precompute_palette = sns.light_palette('C4', 3)
+torch_palette = sns.light_palette('C5', 4)
+torch_loop_palette = sns.light_palette('C6', 4)
+method_palette = OrderedDict({
+    'original': 'C0',
+    'numba4': numba_palette[1],
+    'numba8': numba_palette[2],
+    'vectorized': 'C2',
+    'precompute': 'C3',
+    'numba_precompute8': numba_precompute_palette[2],
+    'torch_precompute_cpu': torch_palette[1],
+    'torch_precompute_mps': torch_palette[2],
+    'torch_precompute_cuda': torch_palette[3],
+    'torch_loop_cpu': torch_loop_palette[1],
+    'torch_loop_mps': torch_loop_palette[2],
+})
+
+out_path = Path('accorr_benchmarks')
+
+def multiply_channels(epochs, i):
+    ch_names = [f'{ch}{x}' for ch in epochs.ch_names for x in range(i)]
+    n_info = mne.create_info(ch_names, epochs.info['sfreq'])
+    return mne.EpochsArray(np.concatenate([epochs.get_data()] * i, axis=1), info=n_info)
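+
+# multiply_channels(epochs, 2) duplicates every channel's data, turning, say,
+# a 32-channel recording into a synthetic 64-channel EpochsArray; this lets
+# the benchmark scale the channel dimension without new recordings.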
+
+
+def benchmark(method, epoch_multiplier, channel_multiplier):
+    # Create scaled dataset by concatenating
+    expanded_preproc_S1 = multiply_channels(preproc_S1, channel_multiplier)
+    expanded_preproc_S2 = multiply_channels(preproc_S2, channel_multiplier)
+
+    epochs_list_S1 = [expanded_preproc_S1.copy() for _ in range(epoch_multiplier)]
+    epochs_list_S2 = [expanded_preproc_S2.copy() for _ in range(epoch_multiplier)]
+
+    preproc_S1_scaled = mne.concatenate_epochs(epochs_list_S1)
+    preproc_S2_scaled = mne.concatenate_epochs(epochs_list_S2)
+
+    # Prepare data for connectivity analysis
+    data_inter = np.array([preproc_S1_scaled, preproc_S2_scaled])
+
+    complex_signal = analyses.compute_freq_bands(
+        data_inter,
+        sampling_rate,
+        freq_bands,
+        filter_length=int(sampling_rate),
+        l_trans_bandwidth=5.0,
+        h_trans_bandwidth=5.0
+    )
+    print(f"\n Testing {method}...")
+    try:
+        st = time.perf_counter()
+        # Call the selected implementation directly: compute_sync expects a
+        # string mode, not a callable
+        result = method_dict[method](complex_signal, epochs_average=False)
+        et = time.perf_counter()
+        perf_time = et - st
+        print(f" Time: {perf_time:.4f}s")
+
+        is_ok = None
+        max_diff = None
+        if method != 'original':
+            orig_path = out_path / f'original-e{epoch_multiplier}-c{channel_multiplier}_result.pkl'
+            if orig_path.is_file():
+                with open(orig_path, 'rb') as f:
+                    orig_result = pickle.load(f)
+
+                # Compare result with orig_result (cast to plain Python types
+                # so the dict stays JSON-serializable)
+                max_diff = np.max(np.abs(result - orig_result))
+                is_ok = bool(np.allclose(result, orig_result, rtol=1e-9, atol=1e-10))
+
+        return {
+            'method': method,
+            'epoch_multiplier': int(epoch_multiplier),
+            'channel_multiplier': int(channel_multiplier),
+            'time': float(perf_time),
+            'result': result,
+            'is_ok': is_ok,
+            'max_diff': float(max_diff) if max_diff is not None else None
+        }
+
+    except Exception:
+        raise
+
+
+benchmark_configs = pd.DataFrame(
+    [
+        [1, 1],
+        [2, 1],
+        [3, 1],
+        [5, 1],
+        [8, 1],
+        [10, 1],
+        [1, 2],
+        [1, 4],
+        [1, 8],
+        [3, 2],
+        [3, 4],
+        [3, 8],
+    ],
+    columns=['epoch_multiplier', 'channel_multiplier']
+)
+
+def task_calc_benchmarks():
+    def benchmark_action(method, epoch_multiplier, channel_multiplier, targets):
+        res = benchmark(method, epoch_multiplier, channel_multiplier)
+
+        with open(targets[1], 'wb') as f:
+            pickle.dump(res['result'], f)
+
+        del res['result']
+        with open(targets[0], 'w') as f:
+            print(res)
+            json.dump(res, f)
+
+    if not out_path.is_dir():
+        out_path.mkdir()
+
+    for _, r in benchmark_configs.iterrows():
+        for method in method_dict:
+            name = f'{method}-e{r["epoch_multiplier"]}-c{r["channel_multiplier"]}'
+            orig_name = f'original-e{r["epoch_multiplier"]}-c{r["channel_multiplier"]}'
+            json_out = out_path / (name + '.json')
+            result_out = out_path / (name + '_result.pkl')
+            yield {
+                'name': name,
+                'actions': [(benchmark_action, (method, r['epoch_multiplier'], r['channel_multiplier']))],
+                'targets': [json_out, result_out],
+                'uptodate': [json_out.is_file()],
+                # non-original methods depend on the original's pickled result
+                'file_dep': [] if method == 'original' else [out_path / (orig_name + '_result.pkl')]
+            }
+
+
+def get_perfs():
+    res = []
+    for p in out_path.glob('*.json'):
+        with open(p) as f:
+            j = json.load(f)
+            res.append(j)
+
+    base_ch_num = len(preproc_S1.info.ch_names) + len(preproc_S2.info.ch_names)
+    base_epoch_num = preproc_S1.get_data().shape[0]
+    perfs = (
+        pd.DataFrame.from_records(res)
+        >> df.mutate(
+            channels = df.X.channel_multiplier * base_ch_num,
+            epochs = df.X.epoch_multiplier * base_epoch_num,
+        )
+    )
+    return perfs
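+
+# The task_* generators below follow the pydoit convention (actions/targets/
+# uptodate/file_dep); a typical invocation (hypothetical) would be:
+#
+#     doit -f tutorial/accorr_test.py calc_benchmarks summary_plots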
+
+
+def task_summary_plots():
+    def action(targets):
+        perfs = get_perfs()
+
+        # Calculate speedup relative to original method
+        speedup_data = []
+        for (ch, ep), group in perfs.groupby(['channels', 'epochs']):
+            original_time = group[group['method'] == 'original']['time'].values
+            if len(original_time) > 0:
+                original_time = original_time[0]
+                for _, row in group.iterrows():
+                    speedup = original_time / row['time']
+                    speedup_data.append({
+                        'method': row['method'],
+                        'channels': ch,
+                        'epochs': ep,
+                        'speedup': speedup
+                    })
+
+        speedup_df = pd.DataFrame(speedup_data)
+
+        fg = sns.catplot(
+            speedup_df,
+            x='channels',
+            y='speedup',
+            col='epochs',
+            col_wrap=3,
+            hue='method',
+            hue_order=method_palette.keys(),
+            palette=method_palette,
+            kind='bar',
+            sharey=False
+        )
+        fg.savefig(targets[0])
+
+    return {
+        'actions': [action],
+        'targets': [out_path / 'benchmark.pdf'],
+        'uptodate': [False]
+    }
+
+
+def task_bad_perfs():
+    def action(targets):
+        bad_perfs = (
+            get_perfs()
+            >> df.mask(df.X.is_ok == False)
+        )
+        bad_perfs.to_csv(targets[0])
+
+    return {
+        'actions': [action],
+        'targets': [out_path / 'bad_perfs.csv'],
+        'uptodate': [False]
+    }
\ No newline at end of file