Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ New contributors: Ryssa MOFFAT, Marine Gautier MARTINS, Rémy RAMADOUR, Patrice

`pip install HyPyP`

> For installation with dependencies optimized algorithms, see (**Poetry installation** below).

## Documentation

HyPyP documentation of all the API functions is available online at [hypyp.readthedocs.io](https://hypyp.readthedocs.io/)
Expand Down Expand Up @@ -51,6 +53,7 @@ For getting started with HyPyP, we have designed a little walkthrough: [getting_

📊 [shiny/\*.py](https://github.com/ppsp-team/HyPyP/blob/master/hypyp/app) — Shiny dashboards, install using `poetry install --extras shiny` (Patrice)


## Poetry Installation (Only for Developers and Adventurous Users)

To develop HyPyP, we recommend using [Poetry 2.x](https://python-poetry.org/). Follow these steps:
Expand Down Expand Up @@ -81,6 +84,26 @@ To install development dependencies, you can run:
poetry install --with dev
```

#### 3.1 Installing optimizations

To use `numba` optimizations, you can run:

```bash
poetry install --with optim_numba
```

To use `torch` optimizations, you can run:

```bash
poetry install --with optim_torch
```

You can also update dependencies on a pre-installed HyPyP with:

```bash
poetry sync --with optim_torch
```

### 4. Launch Jupyter Lab to Run Notebooks:

Instead of entering a shell, launch Jupyter Lab directly within the Poetry environment:
Expand Down
262 changes: 19 additions & 243 deletions hypyp/analyses.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@
import statsmodels.stats.multitest
import copy
from collections import namedtuple
from typing import Union, List, Tuple
from typing import Union, List, Tuple, Optional
import matplotlib.pyplot as plt
from tqdm import tqdm
from hypyp.sync.accorr import accorr
from hypyp.sync.utils import (
_multiply_conjugate,
_multiply_conjugate_time,
_multiply_product
)

plt.ion()

Expand Down Expand Up @@ -435,7 +441,8 @@ def pair_connectivity(data: Union[list, np.ndarray], sampling_rate: int,
return result


def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = True) -> np.ndarray:
def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = True,
optimization: Optional[str] = None) -> np.ndarray:
"""
Computes frequency-domain connectivity measures from analytic signals.

Expand Down Expand Up @@ -463,6 +470,12 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T
If True, connectivity values are averaged across epochs (default)
If False, epoch-by-epoch connectivity is preserved

optimization : str, optional
Allows using optimization strategies. May require extra dependencies.
See README for installation instructions.
Only available for 'accorr'. See sync.accorr.accorr for description of
the optimization options and related dependencies.

Returns
-------
con : np.ndarray
Expand Down Expand Up @@ -504,6 +517,7 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T
# calculate all epochs at once, the only downside is that the disk may not have enough space
complex_signal = complex_signal.transpose((1, 3, 0, 2, 4)).reshape(n_epoch, n_freq, 2 * n_ch, n_samp)
transpose_axes = (0, 1, 3, 2)

if mode.lower() == 'plv':
phase = complex_signal / np.abs(complex_signal)
c = np.real(phase)
Expand Down Expand Up @@ -552,7 +566,8 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T
np.sum(angle ** 2, axis=3))))

elif mode.lower() == 'accorr':
con = _accorr_hybrid(complex_signal, epochs_average=epochs_average, show_progress=True)
con = accorr(complex_signal, epochs_average=epochs_average,
show_progress=True, optimization=optimization)
return con

elif mode.lower() == 'pli':
Expand Down Expand Up @@ -1074,243 +1089,4 @@ def xwt(sig1: mne.Epochs, sig2: mne.Epochs, freqs: Union[int, np.ndarray],
else:
data = 'Please specify a valid mode: power, phase, xwt, or wtc.'
print(data)
return data


# helper function
def _multiply_conjugate(real: np.ndarray, imag: np.ndarray, transpose_axes: tuple) -> np.ndarray:
"""
Computes the product of a complex array and its conjugate efficiently.

This helper function performs matrix multiplication between complex arrays
represented by their real and imaginary parts, collapsing the last dimension.

Parameters
----------
real : np.ndarray
Real part of the complex array

imag : np.ndarray
Imaginary part of the complex array

transpose_axes : tuple
Axes to transpose for matrix multiplication

Returns
-------
product : np.ndarray
Product of the array and its complex conjugate

Notes
-----
This function implements the formula:
product = (real × real.T + imag × imag.T) - i(real × imag.T - imag × real.T)

Using einsum for efficient computation without explicitly creating complex arrays.
"""

formula = 'jilm,jimk->jilk'
product = np.einsum(formula, real, real.transpose(transpose_axes)) + \
np.einsum(formula, imag, imag.transpose(transpose_axes)) - 1j * \
(np.einsum(formula, real, imag.transpose(transpose_axes)) - \
np.einsum(formula, imag, real.transpose(transpose_axes)))

return product


# helper function
def _multiply_conjugate_time(real: np.ndarray, imag: np.ndarray, transpose_axes: tuple) -> np.ndarray:
"""
Computes the product of a complex array and its conjugate without collapsing time dimension.

Similar to _multiply_conjugate, but preserves the time dimension, which is
needed for certain connectivity metrics like wPLI.

Parameters
----------
real : np.ndarray
Real part of the complex array

imag : np.ndarray
Imaginary part of the complex array

transpose_axes : tuple
Axes to transpose for matrix multiplication

Returns
-------
product : np.ndarray
Product of the array and its complex conjugate with time dimension preserved

Notes
-----
This function uses a different einsum formula than _multiply_conjugate:
'jilm,jimk->jilkm' instead of 'jilm,jimk->jilk'

This preserves the time dimension (m) in the output, which is necessary for
computing metrics that require individual time point values rather than
time-averaged products.
"""
formula = 'jilm,jimk->jilkm'
product = np.einsum(formula, real, real.transpose(transpose_axes)) + \
np.einsum(formula, imag, imag.transpose(transpose_axes)) - 1j * \
(np.einsum(formula, real, imag.transpose(transpose_axes)) - \
np.einsum(formula, imag, real.transpose(transpose_axes)))

return product


# helper function
def _multiply_product(real: np.ndarray, imag: np.ndarray, transpose_axes: tuple) -> np.ndarray:
"""
Computes the product of two complex arrays (not conjugate) efficiently.

This helper function performs matrix multiplication between complex arrays
represented by their real and imaginary parts, collapsing the last dimension.
Unlike _multiply_conjugate, this computes z1 * z2 instead of z1 * conj(z2).

Parameters
----------
real : np.ndarray
Real part of the complex array

imag : np.ndarray
Imaginary part of the complex array

transpose_axes : tuple
Axes to transpose for matrix multiplication

Returns
-------
product : np.ndarray
Product of the array with itself (non-conjugate)

Notes
-----
This function implements the formula for z1 * z2:
product = (real × real.T - imag × imag.T) + i(real × imag.T + imag × real.T)

Using einsum for efficient computation without explicitly creating complex arrays.
This is used in the adjusted circular correlation (accorr) metric.
"""
formula = 'jilm,jimk->jilk'
product = np.einsum(formula, real, real.transpose(transpose_axes)) - \
np.einsum(formula, imag, imag.transpose(transpose_axes)) + 1j * \
(np.einsum(formula, real, imag.transpose(transpose_axes)) + \
np.einsum(formula, imag, real.transpose(transpose_axes)))

return product


# helper function
def _accorr_hybrid(complex_signal: np.ndarray, epochs_average: bool = True,
show_progress: bool = True) -> np.ndarray:
"""
Computes Adjusted Circular Correlation using a hybrid approach.

This function calculates the adjusted circular correlation coefficient between
all channel pairs. It uses a vectorized computation for the numerator and an
exact loop-based computation for the denominator.

Parameters
----------
complex_signal : np.ndarray
Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times)
Note: This is the already reshaped signal from compute_sync.

epochs_average : bool, optional
If True, connectivity values are averaged across epochs (default)
If False, epoch-by-epoch connectivity is preserved

show_progress : bool, optional
If True, display a progress bar during computation (default)
If False, no progress bar is shown

Returns
-------
con : np.ndarray
Adjusted circular correlation matrix with shape:
- If epochs_average=True: (n_freq, 2*n_channels, 2*n_channels)
- If epochs_average=False: (n_freq, n_epochs, 2*n_channels, 2*n_channels)

Notes
-----
The adjusted circular correlation is computed as:

1. Numerator (vectorized): Uses the difference between the absolute values of
the conjugate product and the direct product of normalized complex signals.

2. Denominator (loop): For each channel pair, computes optimal phase centering
parameters (m_adj, n_adj) that minimize the denominator, then calculates
the normalization factor.

This metric provides a more accurate measure of circular correlation by
adjusting the phase centering for each channel pair individually, rather than
using a global circular mean.

References
----------
Zimmermann, M., Schultz-Nielsen, K., Dumas, G., & Konvalinka, I. (2024).
Arbitrary methodological decisions skew inter-brain synchronization estimates
in hyperscanning-EEG studies. Imaging Neuroscience, 2.
https://doi.org/10.1162/imag_a_00350
"""
n_epochs = complex_signal.shape[0]
n_freq = complex_signal.shape[1]
n_ch_total = complex_signal.shape[2]

transpose_axes = (0, 1, 3, 2)

# Numerator (vectorized)
z = complex_signal / np.abs(complex_signal)
c, s = np.real(z), np.imag(z)

cross_conj = _multiply_conjugate(c, s, transpose_axes=transpose_axes)
r_minus = np.abs(cross_conj)

cross_prod = _multiply_product(c, s, transpose_axes=transpose_axes)
r_plus = np.abs(cross_prod)

num = r_minus - r_plus

# Denominator (loop)
angle = np.angle(complex_signal)
den = np.zeros((n_epochs, n_freq, n_ch_total, n_ch_total))

total_pairs = (n_ch_total * (n_ch_total + 1)) // 2
pbar = tqdm(total=total_pairs, desc=" accorr (denominator)",
disable=not show_progress, leave=False)

for i in range(n_ch_total):
for j in range(i, n_ch_total):
alpha1 = angle[:, :, i, :]
alpha2 = angle[:, :, j, :]

phase_diff = alpha1 - alpha2
phase_sum = alpha1 + alpha2

mean_diff = np.angle(np.mean(np.exp(1j * phase_diff), axis=2, keepdims=True))
mean_sum = np.angle(np.mean(np.exp(1j * phase_sum), axis=2, keepdims=True))

n_adj = -1 * (mean_diff - mean_sum) / 2
m_adj = mean_diff + n_adj

x_sin = np.sin(alpha1 - m_adj)
y_sin = np.sin(alpha2 - n_adj)

den_ij = 2 * np.sqrt(np.sum(x_sin**2, axis=2) * np.sum(y_sin**2, axis=2))
den[:, :, i, j] = den_ij
den[:, :, j, i] = den_ij

pbar.update(1)

pbar.close()

den = np.where(den == 0, 1, den)
con = num / den
con = con.swapaxes(0, 1) # n_freq x n_epoch x 2*n_ch x 2*n_ch

if epochs_average:
con = np.nanmean(con, axis=1)

return con
return data
Loading