Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,22 @@
## [Unreleased]

### Added
- New connectivity metric: Adjusted Circular Correlation (`accorr`) in `analyses.py`
- Hybrid implementation with vectorized numerator and exact denominator computation
- Progress bar support via `tqdm` for monitoring computation progress
- Available through `pair_connectivity()` and `compute_sync()` functions with `mode='accorr'`
- **New `hypyp.sync` module**: Modular architecture for connectivity metrics
- Extracted 9 connectivity metrics into separate classes: `PLV`, `CCorr`, `ACorr`, `Coh`, `ImCoh`, `PLI`, `WPLI`, `EnvCorr`, `PowCorr`
- `BaseMetric` abstract class for uniform interface across all metrics
- `get_metric(mode, backend)` function for easy metric instantiation
- Backend support infrastructure (numpy default, with future support for numba/torch)
- Helper functions: `multiply_conjugate`, `multiply_conjugate_time`, `multiply_product`

### Changed
- **BREAKING**: `accorr` metric now returns raw connectivity values with shape `(n_epoch, n_freq, 2*n_ch, 2*n_ch)` like all other metrics. The `swapaxes` and `epochs_average` operations are now handled by `compute_sync()` instead of being applied inside the metric.
- Refactored `compute_sync()` to use the new `hypyp.sync` module internally

### Deprecated
- `_multiply_conjugate()` in analyses.py - use `hypyp.sync.multiply_conjugate` instead (will be removed in 1.0.0)
- `_multiply_conjugate_time()` in analyses.py - use `hypyp.sync.multiply_conjugate_time` instead (will be removed in 1.0.0)
- `_multiply_product()` in analyses.py - use `hypyp.sync.multiply_product` instead (will be removed in 1.0.0)
- `_accorr_hybrid()` in analyses.py - use `hypyp.sync.ACorr` instead (will be removed in 1.0.0)

## [0.5.0b13] - 2025-09-18

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ The **Hy**perscanning **Py**thon **P**ipeline
## Contributors

Original authors: Florence BRUN, Anaël AYROLLES, Phoebe CHEN, Amir DJALOVSKI, Yann BEAUXIS, Suzanne DIKKER, Guillaume DUMAS
New contributors: Ryssa MOFFAT, Marine Gautier MARTINS, Rémy RAMADOUR, Patrice FORTIN, Ghazaleh RANJBARAN, Quentin MOREAU, Caitriona DOUGLAS, Franck PORTEOUS, Jonas MAGO, Juan C. AVENDANO, Julie BONNAIRE
New contributors: Ryssa MOFFAT, Marine Gautier MARTINS, Rémy RAMADOUR, Patrice FORTIN, Ghazaleh RANJBARAN, Quentin MOREAU, Caitriona DOUGLAS, Franck PORTEOUS, Jonas MAGO, Juan C. AVENDANO, Julie BONNAIRE, Martín A. MIGUEL

## Installation

Expand Down
174 changes: 74 additions & 100 deletions hypyp/analyses.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
| date | 2020-03-18 |
"""

import warnings
import numpy as np
import scipy
import scipy.signal as signal
Expand All @@ -19,7 +20,7 @@
import statsmodels.stats.multitest
import copy
from collections import namedtuple
from typing import Union, List, Tuple
from typing import Union, List, Tuple, Optional
import matplotlib.pyplot as plt
from tqdm import tqdm

Expand All @@ -30,6 +31,7 @@
from mne.time_frequency import EpochsSpectrum

from .mvarica import MVAR, connectivity_mvarica
from .sync import get_metric


def pow(epochs: mne.Epochs, fmin: float, fmax: float, n_fft: int, n_per_seg: int, epochs_average: bool) -> namedtuple:
Expand Down Expand Up @@ -435,63 +437,47 @@ def pair_connectivity(data: Union[list, np.ndarray], sampling_rate: int,
return result


def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = True) -> np.ndarray:
def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = True,
optimization: Optional[str] = None) -> np.ndarray:
"""
Computes frequency-domain connectivity measures from analytic signals.

This function calculates various connectivity metrics between all possible
channel pairs based on the input complex-valued signals.

Parameters
----------
complex_signal : np.ndarray
Complex analytic signals with shape (2, n_epochs, n_channels, n_freq_bins, n_times)

mode : str
Connectivity measure to compute. Options:
- 'envelope_corr': envelope correlation - correlation between signal envelopes
- 'pow_corr': power correlation - correlation between signal power
- 'envelope_corr' or 'envcorr': envelope correlation - correlation between signal envelopes
- 'pow_corr' or 'powcorr': power correlation - correlation between signal power
- 'plv': phase locking value - consistency of phase differences
- 'ccorr': circular correlation coefficient - circular statistic for phase coupling
- 'accorr': adjusted circular correlation - circular correlation with optimized phase centering
- 'coh': coherence - normalized cross-spectrum
- 'imaginary_coh': imaginary coherence - imaginary part of coherence (volume conduction resistant)
- 'imaginary_coh' or 'imcoh': imaginary coherence - imaginary part of coherence (volume conduction resistant)
- 'pli': phase lag index - asymmetry of phase difference distribution
- 'wpli': weighted phase lag index - weighted version of PLI with improved properties

epochs_average : bool, optional
If True, connectivity values are averaged across epochs (default)
If False, epoch-by-epoch connectivity is preserved

optimization : str, optional
Optimization strategy. May require extra dependencies.
Currently only available for 'accorr'. Options:
- None: standard numpy implementation (default)
- 'auto': best available (torch > numba > numpy)
- 'numba': numba JIT compilation (falls back to numpy if unavailable)
- 'torch': PyTorch with auto-detected GPU (falls back gracefully)

Returns
-------
con : np.ndarray
Connectivity matrix with shape:
- If epochs_average=True: (n_freq, 2*n_channels, 2*n_channels)
- If epochs_average=False: (n_freq, n_epochs, 2*n_channels, 2*n_channels)

Notes
-----
Mathematical formulations for each connectivity measure:

- PLV: |⟨e^(i(φₓ-φᵧ))⟩|
Measures consistency of phase differences across time

- Envelope correlation: corr(env(x), env(y))
Pearson correlation between signal envelopes

- Coherence: |⟨XY*⟩|²/(⟨|X|²⟩⟨|Y|²⟩)
Normalized cross-spectrum

- Imaginary coherence: |Im(⟨XY*⟩)|/√(⟨|X|²⟩⟨|Y|²⟩)
Takes only imaginary part which is less affected by volume conduction

- PLI: |⟨sign(Im(XY*))⟩|
Quantifies asymmetry in phase difference distribution

- wPLI: |⟨|Im(XY*)|sign(Im(XY*))⟩|/⟨|Im(XY*)|⟩
Weighted version that downweights phase differences near 0 or π


Raises
------
ValueError
Expand All @@ -504,74 +490,22 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T
# calculate all epochs at once, the only downside is that the disk may not have enough space
complex_signal = complex_signal.transpose((1, 3, 0, 2, 4)).reshape(n_epoch, n_freq, 2 * n_ch, n_samp)
transpose_axes = (0, 1, 3, 2)
if mode.lower() == 'plv':
phase = complex_signal / np.abs(complex_signal)
c = np.real(phase)
s = np.imag(phase)
dphi = _multiply_conjugate(c, s, transpose_axes=transpose_axes)
con = abs(dphi) / n_samp

elif mode.lower() == 'envelope_corr':
env = np.abs(complex_signal)
mu_env = np.mean(env, axis=3).reshape(n_epoch, n_freq, 2 * n_ch, 1)
env = env - mu_env
con = np.einsum('nilm,nimk->nilk', env, env.transpose(transpose_axes)) / \
np.sqrt(np.einsum('nil,nik->nilk', np.sum(env ** 2, axis=3), np.sum(env ** 2, axis=3)))

elif mode.lower() == 'pow_corr':
env = np.abs(complex_signal) ** 2
mu_env = np.mean(env, axis=3).reshape(n_epoch, n_freq, 2 * n_ch, 1)
env = env - mu_env
con = np.einsum('nilm,nimk->nilk', env, env.transpose(transpose_axes)) / \
np.sqrt(np.einsum('nil,nik->nilk', np.sum(env ** 2, axis=3), np.sum(env ** 2, axis=3)))

elif mode.lower() == 'coh':
c = np.real(complex_signal)
s = np.imag(complex_signal)
amp = np.abs(complex_signal) ** 2
dphi = _multiply_conjugate(c, s, transpose_axes=transpose_axes)
con = np.abs(dphi) / np.sqrt(np.einsum('nil,nik->nilk', np.nansum(amp, axis=3),
np.nansum(amp, axis=3)))

elif mode.lower() == 'imaginary_coh':
c = np.real(complex_signal)
s = np.imag(complex_signal)
amp = np.abs(complex_signal) ** 2
dphi = _multiply_conjugate(c, s, transpose_axes=transpose_axes)
con = np.abs(np.imag(dphi)) / np.sqrt(np.einsum('nil,nik->nilk', np.nansum(amp, axis=3),
np.nansum(amp, axis=3)))

elif mode.lower() == 'ccorr':
angle = np.angle(complex_signal)
mu_angle = circmean(angle, high=np.pi, low=-np.pi, axis=3).reshape(n_epoch, n_freq, 2 * n_ch, 1)
angle = np.sin(angle - mu_angle)

formula = 'nilm,nimk->nilk'
con = np.abs(np.einsum(formula, angle, angle.transpose(transpose_axes)) /
np.sqrt(np.einsum('nil,nik->nilk', np.sum(angle ** 2, axis=3),
np.sum(angle ** 2, axis=3))))

elif mode.lower() == 'accorr':
con = _accorr_hybrid(complex_signal, epochs_average=epochs_average, show_progress=True)
return con

elif mode.lower() == 'pli':
c = np.real(complex_signal)
s = np.imag(complex_signal)
dphi = _multiply_conjugate_time(c, s, transpose_axes=transpose_axes)
con = abs(np.mean(np.sign(np.imag(dphi)), axis=4))

elif mode.lower() == 'wpli':
c = np.real(complex_signal)
s = np.imag(complex_signal)
dphi = _multiply_conjugate_time(c, s, transpose_axes=transpose_axes)
con_num = abs(np.mean(abs(np.imag(dphi)) * np.sign(np.imag(dphi)), axis=4))
con_den = np.mean(abs(np.imag(dphi)), axis=4)
con_den[con_den == 0] = 1
con = con_num / con_den

else:
raise ValueError('Metric type not supported.')
# Normalize mode names (handle aliases)
mode_lower = mode.lower()
mode_map = {
'envelope_corr': 'envcorr',
'pow_corr': 'powcorr',
'imaginary_coh': 'imcoh',
}
mode_normalized = mode_map.get(mode_lower, mode_lower)

# Get the metric from the sync module
try:
metric = get_metric(mode_normalized, optimization=optimization)
con = metric.compute(complex_signal, n_samp, transpose_axes)
except ValueError:
raise ValueError(f'Metric type "{mode}" not supported.')

con = con.swapaxes(0, 1) # n_freq x n_epoch x 2*n_ch x 2*n_ch
if epochs_average:
Expand Down Expand Up @@ -1082,6 +1016,10 @@ def _multiply_conjugate(real: np.ndarray, imag: np.ndarray, transpose_axes: tupl
"""
Computes the product of a complex array and its conjugate efficiently.

.. deprecated:: 0.5.0
This function is deprecated and will be removed in version 1.0.0.
Use :func:`hypyp.sync.multiply_conjugate` instead.

This helper function performs matrix multiplication between complex arrays
represented by their real and imaginary parts, collapsing the last dimension.

Expand All @@ -1108,6 +1046,12 @@ def _multiply_conjugate(real: np.ndarray, imag: np.ndarray, transpose_axes: tupl

Using einsum for efficient computation without explicitly creating complex arrays.
"""
warnings.warn(
"_multiply_conjugate is deprecated and will be removed in version 1.0.0. "
"Use hypyp.sync.multiply_conjugate instead.",
DeprecationWarning,
stacklevel=2
)

formula = 'jilm,jimk->jilk'
product = np.einsum(formula, real, real.transpose(transpose_axes)) + \
Expand All @@ -1123,6 +1067,10 @@ def _multiply_conjugate_time(real: np.ndarray, imag: np.ndarray, transpose_axes:
"""
Computes the product of a complex array and its conjugate without collapsing time dimension.

.. deprecated:: 0.5.0
This function is deprecated and will be removed in version 1.0.0.
Use :func:`hypyp.sync.multiply_conjugate_time` instead.

Similar to _multiply_conjugate, but preserves the time dimension, which is
needed for certain connectivity metrics like wPLI.

Expand Down Expand Up @@ -1151,6 +1099,12 @@ def _multiply_conjugate_time(real: np.ndarray, imag: np.ndarray, transpose_axes:
computing metrics that require individual time point values rather than
time-averaged products.
"""
warnings.warn(
"_multiply_conjugate_time is deprecated and will be removed in version 1.0.0. "
"Use hypyp.sync.multiply_conjugate_time instead.",
DeprecationWarning,
stacklevel=2
)
formula = 'jilm,jimk->jilkm'
product = np.einsum(formula, real, real.transpose(transpose_axes)) + \
np.einsum(formula, imag, imag.transpose(transpose_axes)) - 1j * \
Expand All @@ -1165,6 +1119,10 @@ def _multiply_product(real: np.ndarray, imag: np.ndarray, transpose_axes: tuple)
"""
Computes the product of two complex arrays (not conjugate) efficiently.

.. deprecated:: 0.5.0
This function is deprecated and will be removed in version 1.0.0.
Use :func:`hypyp.sync.multiply_product` instead.

This helper function performs matrix multiplication between complex arrays
represented by their real and imaginary parts, collapsing the last dimension.
Unlike _multiply_conjugate, this computes z1 * z2 instead of z1 * conj(z2).
Expand Down Expand Up @@ -1193,6 +1151,12 @@ def _multiply_product(real: np.ndarray, imag: np.ndarray, transpose_axes: tuple)
Using einsum for efficient computation without explicitly creating complex arrays.
This is used in the adjusted circular correlation (accorr) metric.
"""
warnings.warn(
"_multiply_product is deprecated and will be removed in version 1.0.0. "
"Use hypyp.sync.multiply_product instead.",
DeprecationWarning,
stacklevel=2
)
formula = 'jilm,jimk->jilk'
product = np.einsum(formula, real, real.transpose(transpose_axes)) - \
np.einsum(formula, imag, imag.transpose(transpose_axes)) + 1j * \
Expand All @@ -1208,6 +1172,10 @@ def _accorr_hybrid(complex_signal: np.ndarray, epochs_average: bool = True,
"""
Computes Adjusted Circular Correlation using a hybrid approach.

.. deprecated:: 0.5.0
This function is deprecated and will be removed in version 1.0.0.
Use :class:`hypyp.sync.ACCorr` instead.

This function calculates the adjusted circular correlation coefficient between
all channel pairs. It uses a vectorized computation for the numerator and an
exact loop-based computation for the denominator.
Expand Down Expand Up @@ -1255,6 +1223,12 @@ def _accorr_hybrid(complex_signal: np.ndarray, epochs_average: bool = True,
in hyperscanning-EEG studies. Imaging Neuroscience, 2.
https://doi.org/10.1162/imag_a_00350
"""
warnings.warn(
"_accorr_hybrid is deprecated and will be removed in version 1.0.0. "
"Use hypyp.sync.ACCorr instead.",
DeprecationWarning,
stacklevel=2
)
n_epochs = complex_signal.shape[0]
n_freq = complex_signal.shape[1]
n_ch_total = complex_signal.shape[2]
Expand Down
Loading