diff --git a/.gitignore b/.gitignore index 84cbb7da..868b9da4 100644 --- a/.gitignore +++ b/.gitignore @@ -124,3 +124,6 @@ data2_/ alignment_output_data/ alignment_output_data2/ gen_modules/ +examples/brain_plotting/fsaverage/ +examples/brain_plotting/*.html +docs/sg_execution_times.rst \ No newline at end of file diff --git a/.readthedocs.yml b/.readthedocs.yml index 123aa144..889d7373 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -7,9 +7,9 @@ version: 2 # Set the version of Python and other tools you might need build: - os: ubuntu-20.04 + os: ubuntu-22.04 tools: - python: "3.8" + python: "3.10" # You can also specify other tool versions: # nodejs: "16" # rust: "1.55" diff --git a/docs/references/encoding.rst b/docs/references/encoding.rst index 91c9383c..7bd5a709 100644 --- a/docs/references/encoding.rst +++ b/docs/references/encoding.rst @@ -7,8 +7,18 @@ TRF --- .. autoclass:: TRF - :members: - :exclude-members: get_params, set_params + :members: + :exclude-members: get_params, set_params .. minigallery:: naplib.encoding.TRF - :add-heading: Examples using ``TRF `` \ No newline at end of file + :add-heading: Examples using ``TRF`` + +BandedTRF +--------- + +.. autoclass:: BandedTRF + :members: + :exclude-members: get_params, set_params + +.. minigallery:: naplib.encoding.BandedTRF + :add-heading: Examples using ``BandedTRF`` \ No newline at end of file diff --git a/docs/references/stats.rst b/docs/references/stats.rst index ff03a027..68ba3b3b 100644 --- a/docs/references/stats.rst +++ b/docs/references/stats.rst @@ -3,13 +3,21 @@ Stats .. currentmodule:: naplib.stats +Correlation +----------- + +.. autofunction:: pairwise_correlation + +.. minigallery:: naplib.stats.pairwise_correlation + :add-heading: Examples using ``pairwise_correlation`` + T-Test Responsive Electrodes ---------------------------- .. autofunction:: responsive_ttest .. minigallery:: naplib.stats.responsive_ttest - :add-heading: Examples using ``responsive_ttest `` + :add-heading: Examples using ``responsive_ttest`` T-Test with Feature Control --------------------------- @@ -21,14 +29,18 @@ Discriminability .. autofunction:: discriminability +.. autofunction:: wilks_lambda_discriminability + +.. autofunction:: lda_discriminability + Linear Mixed Effects Model -------------------------- .. autoclass:: LinearMixedEffectsModel - :members: + :members: .. minigallery:: naplib.stats.LinearMixedEffectsModel - :add-heading: Examples using ``LinearMixedEffectsModel `` + :add-heading: Examples using ``LinearMixedEffectsModel`` Stars for P-Values ------------------ diff --git a/docs/requirements.txt b/docs/requirements.txt index 491e35e5..b6193ef5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,11 +1,16 @@ -sphinx>=4.2.0 -sphinx_rtd_theme>=1.0.0 -ipython>=7.4 -ipykernel>=5.1.0 -numpydoc>=1.1.0 -recommonmark==0.5.0 -sphinx-gallery==0.10.1 -mne<1.5 # for building docs, since mne-bids is needed, must have lower version of mne https://github.com/mne-tools/mne-python/pull/11582/files -openneuro-py==2022.4.0 -mne-bids==0.11.1 -nbformat>=4.2.0 +# Documentation Core +sphinx>=7.0.0 +sphinx_rtd_theme>=2.0.0 +numpydoc>=1.6.0 +myst-parser>=2.0.0 # Modern replacement for recommonmark + +# Execution & Gallery +ipython>=8.0 +ipykernel>=6.0 +sphinx-gallery>=0.15.0 # Necessary for modern MNE compatibility +nbformat>=5.0 + +# Neural Data Science Stack +mne>=1.6.0 # Works with OpenNeuro 2026 +mne-bids>=0.14.0 # Fixes the versioning conflict with MNE 1.5+ +openneuro-py==2026.1.0 # Your requested version \ No newline at end of file diff --git a/examples/banded_ridge_TRF_fitting/README.rst b/examples/banded_ridge_TRF_fitting/README.rst new file mode 100644 index 00000000..82c3e8cc --- /dev/null +++ b/examples/banded_ridge_TRF_fitting/README.rst @@ -0,0 +1,2 @@ +Fitting Banded Ridge TRF Models +------------------------------- \ No newline at end of file diff --git a/examples/banded_ridge_TRF_fitting/plot_banded_trf_comparison.py b/examples/banded_ridge_TRF_fitting/plot_banded_trf_comparison.py new file mode 100644 index 00000000..c9cb1ba9 --- /dev/null +++ b/examples/banded_ridge_TRF_fitting/plot_banded_trf_comparison.py @@ -0,0 +1,265 @@ +""" +=========================================================== +TRF Comparison: Iterative RidgeCV vs. Banded Regularization +=========================================================== + +This example compares two approaches for encoding models with multiple +stimulus features: + +1. **Iterative Standard TRF**: Adds features sequentially, optimizing a + single global regularization parameter (alpha) via cross-validation. +2. **Banded TRF**: Adds features sequentially, but optimizes a unique + alpha for each feature band. + +The comparison focuses on predictive accuracy ($R$), marginal improvement ($Delta R$), +and the model's ability to ignore irrelevant noise. +""" + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from scipy.signal import resample +from scipy.stats import zscore, ttest_1samp +import naplib as nl +from naplib.encoding import TRF, BandedTRF +from sklearn.linear_model import Ridge +from mne.decoding.receptive_field import _delay_time_series + +############################################################################### +# 1. Prepare Synthetic Data +# ------------------------- +# We load speech task data and compute the auditory envelope and peak rate. +# A "noise" feature is added to test regularization robustness. + +data = nl.io.load_speech_task_data() +n_trials = 3 +data = data[:n_trials] +feat_fs = 100 + +# Preprocess features +data['aud_spec'] = [resample(nl.features.auditory_spectrogram(trl['sound'], 11025), trl['resp'].shape[0], axis=0) for trl in data] +data['env'] = [zscore(np.sum(trl['aud_spec'], axis=1)) for trl in data] +data['peak_rate'] = [nl.features.peak_rate(trl['aud_spec'], feat_fs) for trl in data] + +# Inject Noise Band: Scaled to match the variance of the envelope +np.random.seed(42) +for i in range(len(data)): + noise = np.random.randn(data[i]['resp'].shape[0]) + data[i]['noise'] = (noise / np.std(noise)) * np.std(data[i]['env']) + +tmin, tmax, sfreq = -0.1, 0.4, 100 +feature_list = ['env', 'noise', 'peak_rate'] +alphas = np.logspace(-2, 5, 15) + +############################################################################### +# 2. Fit Standard TRF with Alpha Path Tracking +# -------------------------------------------- +# We simulate a "Standard" TRF approach by finding a single optimal alpha for +# the combined feature matrix using leave-one-trial-out cross-validation. + +print("Fitting Standard TRF & Tracking Alpha Path...") +standard_p = [] +standard_total_r = [] +standard_delta_r = [] +standard_alpha_paths = [] +prev_r = 0 +prev_r_all = 0 + +for i in range(len(feature_list)): + current_feats = feature_list[:i+1] + all_X = [] + for trl in data: + curr_X = [trl[ft][:, np.newaxis] if trl[ft].ndim == 1 else trl[ft] for ft in current_feats] + curr_X = np.concatenate(curr_X, axis=1) + curr_X = _delay_time_series(curr_X, tmin, tmax, sfreq, fill_mean=False) + curr_X = curr_X.reshape(curr_X.shape[0], -1) + all_X.append(curr_X) + + y = data['resp'] + path_for_this_set = [] + best_alpha_r = -np.inf + + for alpha in alphas: + trial_betas = [Ridge(alpha=alpha).fit(tx, ty).coef_ for tx, ty in zip(all_X, y)] + loto_trial_rs = [] + for t_idx in range(n_trials): + other_indices = [idx for idx in range(n_trials) if idx != t_idx] + avg_coef = np.mean([trial_betas[idx] for idx in other_indices], axis=0) + y_hat = (all_X[t_idx]/alpha) @ avg_coef.T + r = nl.stats.pairwise_correlation(y[t_idx], y_hat) + loto_trial_rs.append(np.mean(r)) + + alpha_r = np.array(loto_trial_rs) + avg_alpha_r = np.mean(loto_trial_rs) + path_for_this_set.append(avg_alpha_r) + + if avg_alpha_r > best_alpha_r: + best_alpha_r = avg_alpha_r + best_alpha_r_all = alpha_r + final_best_model = np.stack(trial_betas, axis=2) + _, p_val = ttest_1samp(alpha_r-prev_r_all, 0) + + standard_alpha_paths.append(path_for_this_set) + standard_total_r.append(best_alpha_r) + standard_delta_r.append(best_alpha_r - prev_r) + standard_p.append(p_val) + prev_r = best_alpha_r + prev_r_all = best_alpha_r_all + +############################################################################### +# 3. Fit Banded TRF +# ----------------- +# The BandedTRF model allows each feature band to have its own optimal +# regularization parameter, determined sequentially. + +print("Fitting Banded TRF...") +banded_model = BandedTRF(tmin=tmin, tmax=tmax, sfreq=sfreq, alphas=alphas) +banded_model.fit(data=data, feature_order=feature_list, target='resp') + +df_summary = banded_model.summary() + +############################################################################### +# 4a. Comprehensive Comparison Plots & Statistics +# ----------------------------------------------- +# Here we compare the cumulative correlation and marginal improvement. + +# Print Statistics for Standard Model +print("\n" + "="*30) +print("STANDARD TRF STATISTICS") +print("="*30) +for i, feat in enumerate(feature_list): + print(f"Feature: {feat:10} | Delta R: {standard_delta_r[i]:.4f} | Significance p: {standard_p[i]:.4f}") + +# Print Statistics for Banded Model +print("\n" + "="*30) +print("BANDED TRF STATISTICS") +print("="*30) +print(df_summary) + +fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + +# Comparison A: Cumulative Predictive Accuracy +banded_cumulative_r = [banded_model.scores_[:,:,i].mean() for i in range(len(feature_list))] +axes[0].plot(feature_list, standard_total_r, 'o--', label='Standard (RidgeCV)', color='#7f7f7f', markersize=8) +axes[0].plot(feature_list, banded_cumulative_r, 'D-', label='Banded TRF', color='#1f77b4', markersize=8) +axes[0].set_title(r'Cumulative Predictive Accuracy ($R$)', fontweight='bold') +axes[0].set_ylabel('Mean Pearson Correlation') +axes[0].set_xlabel('Feature Set (Cumulative)') +axes[0].legend() +axes[0].grid(axis='y', alpha=0.3) + +# Comparison B: Delta R (Unique Variance) +x = np.arange(len(feature_list)) +width = 0.35 +axes[1].bar(x - width/2, standard_delta_r, width, label=r'Standard $\Delta R$', color='#aaaaaa') +axes[1].bar(x + width/2, df_summary['Delta R'], width, label=r'Banded $\Delta R$', color='#d62728') +axes[1].set_xticks(x) +axes[1].set_xticklabels(feature_list) +axes[1].set_title(r'Marginal Improvement ($\Delta R$)', fontweight='bold') +axes[1].set_ylabel(r'Improvement in $R$') +axes[1].set_yscale('symlog', linthresh=1e-4) +axes[1].legend() + +plt.tight_layout() +plt.show() + +############################################################################### +# 4b. Visualization: Alpha Optimization Paths (Standard vs. Banded) +# ----------------------------------------------------------------- +# We compare the optimization curves for each feature. For the Standard model, +# the path represents the best $R$ achievable using a global $\alpha$ as +# features are added. For the Banded model, the path represents the marginal +# improvement ($\Delta R$) gained by optimizing that specific band's alpha. + +colors = {'env': '#1f77b4', 'noise': '#7f7f7f', 'peak_rate': '#d62728'} + +for b_idx, feat in enumerate(feature_list): + fig, axes = plt.subplots(1, 2, figsize=(14, 4), sharey=True) + + # --- Left Plot: Standard TRF (Global Alpha) --- + # In the standard approach, we look at the R-path for the cumulative set + std_path = np.array(standard_alpha_paths[b_idx]) + # Calculate marginal improvement for standard model + prev_std_r = 0 if b_idx == 0 else standard_total_r[b_idx-1] + std_delta_path = std_path - prev_std_r + + best_std_idx = np.argmax(std_delta_path) + axes[0].semilogx(alphas, std_delta_path, 'o-', color='black', alpha=0.6, label=f'Global $\\alpha$ Path') + axes[0].plot(alphas[best_std_idx], std_delta_path[best_std_idx], '*', + markersize=14, markeredgecolor='k', label=f'Opt $\\alpha$: {alphas[best_std_idx]:.1e}') + + axes[0].set_title(f'Standard TRF - Step {b_idx+1}: {feat}') + axes[0].set_xlabel(r'Global Regularization ($\alpha$)') + axes[0].set_ylabel(r'Marginal Improvement ($\Delta R$)') + axes[0].legend(fontsize='small') + + # --- Right Plot: Banded TRF (Independent Alpha) --- + # In the banded approach, we look at the R-path for the specific feature band + banded_path = banded_model.alpha_paths_[feat] + # Calculate marginal improvement relative to previous bands' max R + prev_banded_r = 0 if b_idx == 0 else np.max(banded_model.alpha_paths_[feature_list[b_idx-1]]) + banded_delta_path = banded_path - prev_banded_r + + best_banded_alpha = banded_model.feature_alphas_[feat] + peak_banded_delta = np.max(banded_delta_path) + + axes[1].semilogx(alphas, banded_delta_path, 'o-', color=colors[feat], label=f'Band: {feat}') + axes[1].plot(best_banded_alpha, peak_banded_delta, '*', + markersize=14, markeredgecolor='k', label=f'Opt $\\alpha$: {best_banded_alpha:.1e}') + + axes[1].set_title(f'Banded TRF - Step {b_idx+1}: {feat}') + axes[1].set_xlabel(r'Band-Specific Regularization ($\alpha$)') + axes[1].legend(fontsize='small') + + all_deltas = np.concatenate([std_delta_path, banded_delta_path]) + ymax = all_deltas.max() + ymin = max(all_deltas.min(), -0.005) + axes[0].set_ylim([ymin, ymax+(ymax-ymin)*0.1]) + + plt.tight_layout() + plt.show() + +############################################################################### +# 5. Kernel Comparison: Standard vs. Banded +# ----------------------------------------- +# Inspecting the kernels reveals how Banded TRF better suppresses the +# noise feature by applying an independent regularization penalty. + +best_ch = 0 +lags = np.linspace(tmin, tmax, banded_model._ndelays) +fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True) + +# Extract Standard TRF Kernels +std_coef = final_best_model[best_ch, :, :].reshape(len(feature_list), len(lags), n_trials) + +# Extract Banded TRF Kernels (average across trials) +banded_coef = banded_model.coef_[best_ch] + +colors = ['#1f77b4', '#7f7f7f', '#d62728'] + +for i, feat in enumerate(feature_list): + # Plot TRF with error shading across trials/CV folds + nl.visualization.shaded_error_plot( + lags, std_coef[i, :], + color=colors[i], + ax=axes[0], + plt_args={'label': f'Std: {feat}', 'lw': 2} + ) + nl.visualization.shaded_error_plot( + lags, banded_coef[i, :], + color=colors[i], + ax=axes[1], + plt_args={'label': f'Banded: {feat}', 'lw': 2} + ) + +axes[0].set_title(f'Standard TRF Kernels (Global $\\alpha$)\nChannel {best_ch}') +axes[1].set_title(f'Banded TRF Kernels (Independent $\\alpha$)\nChannel {best_ch}') + +for ax in axes: + ax.axhline(0, color='black', lw=1, alpha=0.5) + ax.set_xlabel('Lag (s)') + ax.legend(fontsize='small', frameon=False) + +axes[0].set_ylabel('Weights (a.u.)') +plt.tight_layout() +plt.show() \ No newline at end of file diff --git a/examples/banded_ridge_TRF_fitting/plot_banded_trf_optimization.py b/examples/banded_ridge_TRF_fitting/plot_banded_trf_optimization.py new file mode 100644 index 00000000..bdc36453 --- /dev/null +++ b/examples/banded_ridge_TRF_fitting/plot_banded_trf_optimization.py @@ -0,0 +1,182 @@ +r""" +=================================================== +Banded Ridge: Robustness Check with Null Bands +=================================================== + +This example provides a rigorous sanity check for BandedTRF. We insert a +"Null Band" (random Gaussian noise) between our meaningful features to +ensure the model correctly regularizes irrelevant information. + +Robustness Checks included: +1. Stimulus Alignment Visualization. +2. Step-wise Marginal Delta R optimization paths. +3. Order-invariance consistency (Scatter of Order 1 vs Order 2). +4. Kernel weight inspection for noise suppression. +5. Statistical significance via the .summary() method. +""" + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from scipy.signal import resample +from scipy.stats import zscore +import naplib as nl +from naplib.encoding import BandedTRF + +############################################################################### +# 1. Prepare the Data +# ------------------- +# Load neural responses to speech and preprocess features. We include +# speech envelope, peak rate, and a "Null" noise band for validation. + +data = nl.io.load_speech_task_data() +n_trials = 3 +data = data[:n_trials] + +# Standardize neural responses +data['resp'] = nl.preprocessing.normalize(data=data, field='resp') + +# Step A: Compute auditory spectrogram and align to modeling rate (100Hz) +spec_fs, feat_fs = 11025, 100 +data['spec'] = [nl.features.auditory_spectrogram(trl['sound'], spec_fs) for trl in data] +# Resample spectrogram to match neural response length +data['spec'] = [resample(trial['spec'], trial['resp'].shape[0]) for trial in data] + +# Step B: Compute Envelope and Peak Rate (acoustic features) +data['env'] = [zscore(np.sum(trl['spec'], axis=1)) for trl in data] +data['peak_rate'] = [nl.features.peak_rate(trl['spec'], feat_fs, band=[1, 10]) for trl in data] + +# Step C: Final alignment and "Null" Noise Injection +# We inject noise to verify that BandedTRF assigns it a high lambda (regularization) +np.random.seed(1) +for i, trial in enumerate(data): + # Null Band: Gaussian noise scaled to match envelope variance + noise = np.random.randn(trial['resp'].shape[0]) + data[i]['noise'] = (noise / np.std(noise)) * np.std(data[i]['env']) + +############################################################################### +# 2. Visualize Stimulus Features +# ------------------------------ +# Check the temporal alignment of the envelope, peak rate, and injected noise. + +fig, ax = plt.subplots(figsize=(12, 3)) +t = np.arange(500) / feat_fs +ax.plot(t, data[0]['env'][:500], label='Envelope', color='#1f77b4') +ax.plot(t, data[0]['peak_rate'][:500], label='Peak Rate', color='#d62728') +ax.plot(t, data[0]['noise'][:500], label='Noise (Null)', color='#7f7f7f', alpha=0.5) +ax.set_title('Stimulus Features (First 5 Seconds)') +ax.set_xlabel('Time (s)') +ax.set_ylabel('Amplitude (z-score)') +ax.legend(loc='upper right', fontsize='small', ncol=3) +plt.show() + +############################################################################### +# 3. Fit Models with Injected Noise (Order Dependency) +# ---------------------------------------------------- +# BandedTRF uses a greedy, step-wise approach. We test if the order of +# feature entry affects the final predictive performance. + +tmin, tmax, sfreq = -0.2, 0.5, 100 +alphas = np.logspace(-2, 8, 11) + +# Fit Model 1: Envelope -> Noise -> Peak Rate +order_1 = ['env', 'noise', 'peak_rate'] +model1 = BandedTRF(tmin=tmin, tmax=tmax, sfreq=sfreq, alphas=alphas) +model1.fit(data=data, feature_order=order_1, target='resp') + +# Fit Model 2: Peak Rate -> Noise -> Envelope +order_2 = ['peak_rate', 'noise', 'env'] +model2 = BandedTRF(tmin=tmin, tmax=tmax, sfreq=sfreq, alphas=alphas) +model2.fit(data=data, feature_order=order_2, target='resp') + +############################################################################### +# 4. Alpha Optimization Paths (Marginal Delta R) +# ---------------------------------------------- +# Visualize how much each feature adds to the correlation (r) at each step. +# For the noise band, we expect a flat or negligible marginal improvement. + +colors = {'env': '#1f77b4', 'noise': '#7f7f7f', 'peak_rate': '#d62728'} +n_bands = len(order_1) + +for b_idx in range(n_bands): + fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=False) + for i, (mdl, ord_list) in enumerate(zip([model1, model2], [order_1, order_2])): + feat = ord_list[b_idx] + path = mdl.alpha_paths_[feat] + + # Calculate Delta R Path relative to the max R of the previous band + prev_r = 0 if b_idx == 0 else np.max(mdl.alpha_paths_[ord_list[b_idx-1]]) + delta_path = path - prev_r + + best_alpha = mdl.feature_alphas_[feat] + peak_delta = np.max(delta_path) + + axes[i].semilogx(alphas, delta_path, marker='o', color=colors[feat], label=f'Path: {feat}') + axes[i].plot(best_alpha, peak_delta, '*', markersize=14, markeredgecolor='k', label=f'Selected $\lambda$') + axes[i].set_title(f'Order {i+1} - Step {b_idx+1}: {feat}') + axes[i].set_xlabel(r'Regularization Alpha ($\lambda$)') + axes[i].legend() + + axes[0].set_ylabel(r'Marginal Improvement ($\Delta R$)') + plt.tight_layout() + plt.show() + +############################################################################### +# 5. Global Consistency: Order 1 vs Order 2 +# ----------------------------------------- +# A robust banded model should yield similar final predictive accuracies +# regardless of the order in which features were added. + +r_full_1 = model1.scores_[:,:,-1].mean(axis=0) +r_full_2 = model2.scores_[:,:,-1].mean(axis=0) + +fig, ax = plt.subplots(figsize=(5, 5)) +ax.scatter(r_full_1, r_full_2, s=50, alpha=0.6, edgecolors='w', color='purple') +# Set limits based on data range +min_r = min(r_full_1.min(), r_full_2.min()) +max_r = max(r_full_1.max(), r_full_2.max()) +lims = [min_r, max_r] +ax.plot(lims, lims, 'k--', alpha=0.5, label='Unity (Order Independent)') +ax.set_title('Cross-Order Consistency') +ax.set_xlabel('Mean Accuracy $r$ (Order 1)') +ax.set_ylabel('Mean Accuracy $r$ (Order 2)') +ax.legend() +plt.show() + +############################################################################### +# 6. Final Model Kernels for the Best Channel +# ------------------------------------------- +# Inspect temporal response functions (TRFs). The 'noise' band TRF should +# be close to zero, while 'env' and 'peak_rate' should show clear peaks. + +best_ch = np.argmax(r_full_1) +fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True) +lags = np.linspace(tmin, tmax, model1._ndelays) + +for i, (mdl, ord_list, title) in enumerate(zip([model1, model2], + [order_1, order_2], + ['Kernels (Order 1)', 'Kernels (Order 2)'])): + for f_idx, feat in enumerate(ord_list): + # Plot TRF with error shading across trials/CV folds + nl.visualization.shaded_error_plot( + lags, mdl.coef_[best_ch, f_idx, :], + ax=axes[i], color=colors[feat], + plt_args={'label': feat, 'lw': 2} + ) + + axes[i].axhline(0, color='black', alpha=0.5, linestyle=':') + axes[i].axvline(0, color='black', alpha=0.5, linestyle=':') + axes[i].set_title(f"{title} - Electrode {best_ch}") + axes[i].set_xlabel('Time Lag (s)') + axes[i].legend(fontsize='small', frameon=False) + +axes[0].set_ylabel('Filter Weight (a.u.)') +plt.tight_layout() +plt.show() + +# Statistical Significance Summary for the most responsive electrode +print(f"\nFinal Statistics for Model 1 (Order: {order_1}), Electrode {best_ch}:") +model1.summary(best_ch) + +print(f"\nFinal Statistics for Model 2 (Order: {order_2}), Electrode {best_ch}:") +model2.summary(best_ch) \ No newline at end of file diff --git a/naplib/encoding/__init__.py b/naplib/encoding/__init__.py index 502cc76b..25408afb 100644 --- a/naplib/encoding/__init__.py +++ b/naplib/encoding/__init__.py @@ -1,3 +1,9 @@ +''' +Models for encoding and decoding neural data, such as +Temporal Receptive Fields (TRFs). +''' + from .trf import TRF +from .banded_trf import BandedTRF -__all__ = ['TRF'] +__all__ = ['TRF', 'BandedTRF'] \ No newline at end of file diff --git a/naplib/encoding/banded_trf.py b/naplib/encoding/banded_trf.py new file mode 100644 index 00000000..975b41dd --- /dev/null +++ b/naplib/encoding/banded_trf.py @@ -0,0 +1,366 @@ +import numpy as np +import pandas as pd +from tqdm.auto import tqdm +from scipy.stats import ttest_1samp +from sklearn.base import BaseEstimator +from sklearn.linear_model import Ridge +from mne.decoding.receptive_field import _delay_time_series +from ..stats import pairwise_correlation +from ..utils import _parse_outstruct_args + +class BandedTRF(BaseEstimator): + r""" + Iterative Banded Ridge TRF model. + + Fits features sequentially in bands. For each band, the regularization (alpha) + is optimized via leave-one-trial-out cross-validation using coefficient averaging + for computational efficiency. + + Parameters + ---------- + tmin : float + Starting lag (seconds). + tmax : float + Ending lag (seconds). + sfreq : float + Sampling frequency (Hz). + alphas : np.ndarray, optional + Alphas to sweep for each feature. Default is np.logspace(-2, 5, 8). + basis_dict : dict, optional + Dictionary mapping feature names to basis objects. + """ + def __init__(self, tmin, tmax, sfreq, alphas=None, basis_dict=None): + self.tmin = tmin + self.tmax = tmax + self.sfreq = sfreq + self.alphas = alphas if alphas is not None else np.logspace(-2, 5, 8) + self.basis_dict = basis_dict if basis_dict is not None else {} + self.feature_alphas_ = {} + self.alpha_paths_ = {} + self.feature_order_ = [] + self.model_ = None # Will store a list of fitted Ridge models (one per trial) + self.target_ = None + self.scores_ = None # Shape: (n_trials, n_channels, n_features) + + @property + def _ndelays(self): + return int(round(self.tmax * self.sfreq)) - int(round(self.tmin * self.sfreq)) + 1 + + @property + def coef_(self): + if self.model_ is None: + raise AttributeError("BandedTRF has not been fitted yet.") + + n_trials = len(self.model_) + n_feats = len(self.feature_order_) + + # Force coefficients to be 2D (n_targets, n_features_total) + # This fixes the 3.8 vs 3.10 discrepancy + trial_coefs = [] + for m in self.model_: + c = m.coef_ + if c.ndim == 1: + c = c[np.newaxis, :] + trial_coefs.append(c) + + n_targets = trial_coefs[0].shape[0] + all_coefs = np.stack(trial_coefs, axis=-1) + + return all_coefs.reshape(n_targets, n_feats, self._ndelays, n_trials) + + def _prepare_matrix(self, X_list, feature_names, alphas_dict): + processed_trials = [] + n_trials = len(X_list[0]) + + for trl in range(n_trials): + mats = [] + for i, name in enumerate(feature_names): + x = X_list[i][trl] + + if isinstance(x, list) and len(x) == 1: + x = x[0] + + if np.isscalar(x) or x is None: + continue + if x.ndim == 1: + x = x[:, np.newaxis] + + if name in self.basis_dict: + x = self.basis_dict[name].transform(x) + + alpha = alphas_dict.get(name, 1.0) + mats.append(x / np.sqrt(alpha)) + + if not mats: + raise ValueError("No features were successfully processed.") + + concatenated = np.concatenate(mats, axis=1) + delayed = _delay_time_series(concatenated, self.tmin, self.tmax, self.sfreq) + processed_trials.append(delayed.reshape(delayed.shape[0], -1)) + return processed_trials + + def fit(self, data, feature_order, target='resp'): + r""" + Fit the Iterative Banded Ridge model using leave-one-trial-out cross-validation. + + The model fits features sequentially according to `feature_order`. For each + new feature band, an optimal regularization parameter (alpha) is selected + from `self.alphas` by maximizing the average prediction correlation across + held-out trials. + + Parameters + ---------- + data : naplib.OutStruct or list of dict + The data containing the features and target signal. Must be a format + compatible with `naplib.utils.parse_outstruct_args`. + feature_order : list of str + The ordered list of field names in `data` to be used as feature bands. + Features are added to the model sequentially. + target : str, default='resp' + The field name in `data` containing the dependent variable (e.g., + neural responses). + + Returns + ------- + self : BandedTRF + Returns the instance of the fitted model. + + Notes + ----- + The cross-validation uses 'coefficient averaging' for efficiency. For + each alpha in the sweep, a model is fit to each trial individually. + The prediction for a held-out trial $i$ is generated using the mean + coefficients of all trials $j \neq i$. + """ + self.feature_order_ = feature_order + self.target_ = target + + y = _parse_outstruct_args(data, target) + if not isinstance(y, list): y = [y] + + n_trials = len(y) + self.n_targets_ = y[0].shape[1] + + all_features_data = [] + for f in feature_order: + f_data = _parse_outstruct_args(data, f) + all_features_data.append(f_data if isinstance(f_data, list) else [f_data]) + + self.scores_ = np.zeros((n_trials, self.n_targets_, len(feature_order))) + + for i, current_feat in enumerate(feature_order): + best_alpha = None + max_r = -np.inf + r_history = [] + best_r_per_trial_ch = None + + for alpha in tqdm(self.alphas, desc=f"Optimizing {current_feat}", leave=False): + temp_alphas = {**self.feature_alphas_, current_feat: alpha} + X_mats = self._prepare_matrix(all_features_data[:i+1], feature_order[:i+1], temp_alphas) + + trial_betas = [Ridge(alpha=1.0).fit(tx, ty.reshape(-1, self.n_targets_)).coef_ for tx, ty in zip(X_mats, y)] + + current_alpha_trial_r = np.zeros((n_trials, self.n_targets_)) + for test_idx in range(n_trials): + train_indices = [j for j in range(n_trials) if j != test_idx] + avg_beta = np.mean([trial_betas[j] for j in train_indices], axis=0) + y_pred = X_mats[test_idx] @ avg_beta.T + + # Ensure y is 2D: (samples, targets) + y_true = y[test_idx] + if y_true.ndim == 1: + y_true = y_true[:, np.newaxis] + + # Ensure y_pred is 2D: (samples, targets) + if y_pred.ndim == 1: + y_pred = y_pred[:, np.newaxis] + + # This returns an array of shape (n_targets,) + r_values = pairwise_correlation(y_true, y_pred) + current_alpha_trial_r[test_idx, :] = r_values + + avg_r = np.nanmean(current_alpha_trial_r) + r_history.append(avg_r) + if avg_r > max_r or np.isclose(avg_r, max_r): + max_r, best_alpha = avg_r, alpha + best_r_per_trial_ch = current_alpha_trial_r + + self.feature_alphas_[current_feat] = best_alpha + self.alpha_paths_[current_feat] = np.array(r_history) + self.scores_[:, :, i] = best_r_per_trial_ch + + # Final fit on each trial separately + final_X = self._prepare_matrix(all_features_data, feature_order, self.feature_alphas_) + self.model_ = [Ridge(alpha=1.0).fit(tx, ty) for tx, ty in zip(final_X, y)] + + self.feat_dims_ = [] + for i, name in enumerate(feature_order): + x_sample = all_features_data[i][0] + if isinstance(x_sample, list): x_sample = x_sample[0] + if x_sample.ndim == 1: x_sample = x_sample[:, None] + if name in self.basis_dict: + x_sample = self.basis_dict[name].transform(x_sample) + self.feat_dims_.append(x_sample.shape[1]) + + return self + + def predict(self, data, feature_names=None): + """ + Predict target responses using the fitted Banded Ridge model. + + This method performs Leave-One-Trial-Out (LOTO) prediction. For each + trial in the input data, it averages the regression coefficients + from all *other* trials (fitted during training) to generate the + prediction for the current trial. + + Parameters + ---------- + data : naplib.OutStruct or list of dict + The data containing the features to predict from. Must contain + the same number of trials as used during `fit`. + feature_names : list of str, optional + The subset of features to use for prediction. If None (default), + uses all features specified in the `feature_order` during `fit`. + This allows for isolating the contribution of specific bands. + + Returns + ------- + preds : list of np.ndarray + Predicted target values for each trial. Each element is an + array of shape (n_samples, n_targets). + + Raises + ------ + ValueError + If the model has not been fitted, or if the number of trials + in `data` does not match the number of models in `self.model_`. + + Notes + ----- + Because this model stores a separate fit for every trial to enable + efficient cross-validation, the `predict` step requires the input + to have a one-to-one mapping with the training trials. + """ + if self.model_ is None: + raise ValueError("Model must be fitted before calling predict.") + + requested_features = feature_names if feature_names else self.feature_order_ + + # Standardize feature data to list of trial-lists + feat_data_list = [] + for f in requested_features: + f_data = _parse_outstruct_args(data, f) + feat_data_list.append(f_data if isinstance(f_data, list) else [f_data]) + + X_mats = self._prepare_matrix(feat_data_list, requested_features, self.feature_alphas_) + n_trials = len(X_mats) + + if n_trials != len(self.model_): + raise ValueError( + f"LOTO predict requires the same number of trials ({len(self.model_)}) " + f"as used in fit. Found {n_trials} trials." + ) + + all_coefs = np.array([m.coef_ for m in self.model_]) + if all_coefs.ndim == 2: + # Expand (trials, features) -> (trials, 1_target, features) + all_coefs = all_coefs[:, np.newaxis, :] + + # Handle feature masking if a subset is requested + mask = np.ones(all_coefs.shape[2], dtype=bool) + if feature_names is not None: + mask = np.zeros(all_coefs.shape[2], dtype=bool) + current_col = 0 + for i, f in enumerate(self.feature_order_): + num_cols = self.feat_dims_[i] * self._ndelays + if f in requested_features: + mask[current_col : current_col + num_cols] = True + current_col += num_cols + + preds = [] + for i in range(n_trials): + # Indices for all trials except the current one + loto_indices = [j for j in range(n_trials) if j != i] + + # Average coefficients and intercepts from the other trials + loto_coef = np.mean(all_coefs[loto_indices], axis=0) + + # Apply feature mask + sliced_coef = loto_coef[:, mask] + + # Predict for the current trial + preds.append(X_mats[i] @ sliced_coef.T) + + return preds + + def summary(self, channel=None): + r""" + Generate a statistical report of feature contributions and model performance. + + Calculates the incremental improvement (Delta R) for each feature band + added to the model and performs a one-sample t-test (alternative='greater') + across trials to determine if the contribution is significantly greater + than zero. + + Parameters + ---------- + channel : int, optional + The specific target channel (e.g., electrode or sensor) to summarize. + If None (default), results are averaged across all channels. + + Returns + ------- + df : pandas.DataFrame + A summary table indexed by 'Feature' containing: + - Total R: Cumulative correlation after adding this feature. + - Delta R: Incremental correlation increase attributed to this feature. + - Alpha: The optimized regularization parameter for the band. + - p-value: Significance of the Delta R across trials (t-test). + + Notes + ----- + The Delta R for the first feature is its Total R. For subsequent + features, Delta R is calculated as: + $ \Delta R_{n} = R_{n} - R_{n-1} $ + + Significant p-values suggest that the addition of a specific feature + band significantly improves the model's predictive power on + held-out data. + """ + if self.scores_ is None: + raise ValueError("Model must be fitted before calling summary.") + + dr_tensor = np.diff(self.scores_, axis=2, prepend=0) + + if channel is not None: + r_report = self.scores_[:, channel, :] + dr_report = dr_tensor[:, channel, :] + ch_label = f"Channel {channel}" + else: + r_report = np.nanmean(self.scores_, axis=1) + dr_report = np.nanmean(dr_tensor, axis=1) + ch_label = "Global Mean (All Channels)" + + summary_results = [] + for f_idx, feat in enumerate(self.feature_order_): + sample = dr_report[:, f_idx] + clean_sample = sample[~np.isnan(sample)] + if len(clean_sample) < 2 or np.all(clean_sample == clean_sample[0]): + p_val = 1.0 if np.mean(clean_sample) <= 0 else 0.0 + else: + _, p_val = ttest_1samp(clean_sample, 0, alternative='greater') + + summary_results.append({ + 'Feature': feat, + 'Total R': np.nanmean(r_report[:, f_idx]), + 'Delta R': np.nanmean(dr_report[:, f_idx]), + 'Alpha': self.feature_alphas_[feat], + 'p-value': p_val, + }) + + df = pd.DataFrame(summary_results).set_index('Feature') + print(f"\nBandedTRF Summary | {ch_label}\n" + "-" * 70) + print(df.to_string(formatters={'Total R': '{:,.4f}'.format, + 'Delta R': '{:,.4f}'.format, + 'Alpha': '{:,.2e}'.format})) + return df \ No newline at end of file diff --git a/naplib/io/load_bids.py b/naplib/io/load_bids.py index 3cedd5bb..1f056684 100644 --- a/naplib/io/load_bids.py +++ b/naplib/io/load_bids.py @@ -3,13 +3,14 @@ from naplib import logger from ..data import Data -ACCEPTED_CROP_BY = ['onset', 'durations'] +ACCEPTED_CROP_BY = ['onset', 'durations', None] def load_bids(root, subject, datatype, task, suffix, + run=None, session=None, befaft=[0, 0], crop_by='onset', @@ -33,6 +34,8 @@ def load_bids(root, Task name. suffix : string Suffix name in file naming. This is often the same as datatype. + run : string + Run name. session : string Session name. befaft : list or array-like or length 2, default=[0, 0] @@ -89,7 +92,7 @@ def load_bids(root, raise ValueError(f'Invalid "crop_by" input. Expected one of {ACCEPTED_CROP_BY} but got "{crop_by}"') bids_path = BIDSPath(subject=subject, root=root, session=session, task=task, - suffix=suffix, datatype=datatype) + run=run, suffix=suffix, datatype=datatype) raw = read_raw_bids(bids_path=bids_path) @@ -123,8 +126,9 @@ def load_bids(root, for trial in tqdm(range(len(raws))): trial_data = {} trial_data['event_index'] = trial - if 'description' in raw_responses[trial].annotations[0]: - trial_data['description'] = raw_responses[trial].annotations[0]['description'] + if raw_responses[trial].annotations: + if 'description' in raw_responses[trial].annotations[0]: + trial_data['description'] = raw_responses[trial].annotations[0]['description'] if raw_stims[trial] is not None: trial_data['stim'] = raw_stims[trial].get_data().transpose(1,0) # time by channels trial_data['stim_ch_names'] = raw_stims[trial].info['ch_names'] @@ -138,7 +142,8 @@ def load_bids(root, new_data.append(trial_data) data_ = Data(new_data, strict=False) - data_.set_mne_info(raw_info) + if raw_info is not None: + data_.set_mne_info(raw_info) return data_ @@ -151,14 +156,14 @@ def _crop_raw_bids(raw_instance, crop_by, befaft): raw_instance : mne.io.Raw-like object crop_by : string, default='onset' - One of ['onset', 'annotations']. If crop by 'onset', each trial is split + One of ['onset', 'annotations', None]. If crop by 'onset', each trial is split by the onset of each event defined in the BIDS file structure and each trial ends when the next trial begins. If crop by 'annotations', each trial is split by the onset of each event defined in the BIDS file structure and each trial lasts the duration specified by the event. This is typically not desired when the events are momentary stimulus presentations that have very short duration because only the responses during the short duration of the event will be saved, and - all of the following responses are truncated. + all of the following responses are truncated. If None, no cropping. Returns ------- @@ -166,6 +171,8 @@ def _crop_raw_bids(raw_instance, crop_by, befaft): The cropped raw objects. ''' + if crop_by == None: + return [raw_instance.copy()] max_time = (raw_instance.n_times - 1) / raw_instance.info['sfreq'] diff --git a/naplib/stats/__init__.py b/naplib/stats/__init__.py index 88e0d756..4c87425f 100644 --- a/naplib/stats/__init__.py +++ b/naplib/stats/__init__.py @@ -1,7 +1,7 @@ -from .encoding import discriminability +from .encoding import discriminability, pairwise_correlation from .mixedeffectsmodel import LinearMixedEffectsModel from .pvalues import stars from .responsive_ttest import responsive_ttest from .ttest import ttest -__all__ = ['discriminability','LinearMixedEffectsModel','stars','responsive_ttest', 'ttest'] +__all__ = ['discriminability','pairwise_correlation','LinearMixedEffectsModel','stars','responsive_ttest', 'ttest'] diff --git a/naplib/stats/encoding.py b/naplib/stats/encoding.py index eb9b83a0..6fa4d8d2 100644 --- a/naplib/stats/encoding.py +++ b/naplib/stats/encoding.py @@ -210,3 +210,56 @@ def _compute_discrim(x_data, labels_data): return f_stat + +import numpy as np + +import numpy as np + +def pairwise_correlation(A, B, axis=0): + r""" + Compute Pearson correlation between A and B along a specified axis. + + The correlation is computed pairwise for each corresponding element + along the remaining dimensions. The output will have the same shape + as the inputs, but with the specified ``axis`` removed. + + The correlation is calculated as: + $$r = \frac{\sum (A_i - \bar{A})(B_i - \bar{B})}{\sqrt{\sum (A_i - \bar{A})^2 \sum (B_i - \bar{B})^2}}$$ + + Parameters + ---------- + A : np.ndarray + First array. + B : np.ndarray + Second array. Must be the same shape as A. + axis : int, default=0 + The axis along which to compute the correlation (e.g., the time dimension). + + Returns + ------- + corr : np.ndarray or float + Pairwise correlations. If inputs are 1D, returns a float. + Otherwise, returns an array of shape equal to the input shape + with the ``axis`` dimension removed. + """ + A = np.asarray(A) + B = np.asarray(B) + + if A.shape != B.shape: + raise ValueError(f"A and B must have the same shape, but got {A.shape} and {B.shape}") + + # 1. Center the data along the specified axis + # keepdims=True is essential for broadcasting subtraction + am = A - np.mean(A, axis=axis, keepdims=True) + bm = B - np.mean(B, axis=axis, keepdims=True) + + # 2. Compute sum of squares (variance proxies) + a_ss = np.sum(am**2, axis=axis) + b_ss = np.sum(bm**2, axis=axis) + + # 3. Compute covariance proxy + coscale = np.sum(am * bm, axis=axis) + + # 4. Return normalized correlation + # 1e-15 prevents division by zero for constant signals + return coscale / (np.sqrt(a_ss * b_ss) + 1e-15) \ No newline at end of file diff --git a/tests/encoding/test_banded_trf.py b/tests/encoding/test_banded_trf.py new file mode 100644 index 00000000..c755de09 --- /dev/null +++ b/tests/encoding/test_banded_trf.py @@ -0,0 +1,89 @@ +import pytest +import numpy as np +import pandas as pd +from sklearn.linear_model import Ridge +from naplib import Data +from naplib.encoding import BandedTRF +from naplib.stats import pairwise_correlation + +@pytest.fixture(scope='module') +def synth_data(): + """ + Generate synthetic data with 2 target channels. + This ensures that Ridge.coef_ returns a 2D array (n_targets, n_features), + making the stacked 'all_coefs' 3D (n_trials, n_targets, n_features) + and preventing IndexErrors in the masking logic. + """ + rng = np.random.default_rng(42) + fs, n_samples, n_trials = 100, 1000, 3 + trials = [] + for _ in range(n_trials): + x1 = rng.standard_normal(size=(n_samples, 1)) + x2 = rng.standard_normal(size=(n_samples, 1)) + + # Create 2 target channels (multi-output) + y1 = (x1 * 1.0 + np.roll(x2, 2) * 0.5) + y2 = (x1 * 0.5 + np.roll(x2, 1) * 1.0) + resp = np.hstack([y1, y2]) + 0.01 * rng.standard_normal((n_samples, 2)) + + trials.append({'resp': resp, 'stim1': x1, 'stim2': x2}) + + return { + 'data': Data(trials), + 'feature_order': ['stim1', 'stim2'], + 'tmin': 0, 'tmax': 0.03, 'sfreq': fs + } + +def test_banded_trf_loto_consistency(synth_data): + """Test that coef_ property handles the 4D reshape correctly.""" + model = BandedTRF(tmin=synth_data['tmin'], tmax=synth_data['tmax'], + sfreq=synth_data['sfreq'], alphas=[0.1, 10.0]) + model.fit(data=synth_data['data'], feature_order=synth_data['feature_order'], target='resp') + + # Shape calculation: 2 targets, 2 features, 4 delays, 3 trials. + # ndelays = (0.03 * 100) - (0 * 100) + 1 = 4. + assert model.coef_.shape == (2, 2, 4, 3) + +def test_predict_masking_logic(synth_data): + """Verify that partial feature prediction works with multi-channel targets.""" + model = BandedTRF(tmin=synth_data['tmin'], tmax=synth_data['tmax'], sfreq=synth_data['sfreq']) + model.fit(data=synth_data['data'], feature_order=synth_data['feature_order'], target='resp') + + # Full prediction: should match target shape (samples, channels) + preds_all = model.predict(synth_data['data']) + assert len(preds_all) == 3 + assert preds_all[0].shape == (1000, 2) + + # Partial prediction: Triggers the internal mask logic + # multi-channel data ensures all_coefs.ndim == 3, avoiding IndexError + preds_sub = model.predict(synth_data['data'], feature_names=['stim1']) + assert len(preds_sub) == 3 + assert preds_sub[0].shape == (1000, 2) + +def test_summary_p_values(synth_data): + """Verify summary table computes stats across channels correctly.""" + model = BandedTRF(tmin=synth_data['tmin'], tmax=synth_data['tmax'], sfreq=synth_data['sfreq']) + model.fit(data=synth_data['data'], feature_order=synth_data['feature_order'], target='resp') + + df = model.summary() + assert isinstance(df, pd.DataFrame) + assert 'Delta R' in df.columns + assert 'p-value' in df.columns + # Check that p-values are valid numbers + assert not df['p-value'].isna().any() + +def test_unfitted_attribute_error(): + """Verify custom AttributeError message for unfitted models.""" + model = BandedTRF(0, 0.1, 100) + with pytest.raises(AttributeError, match="BandedTRF has not been fitted yet."): + _ = model.coef_ + +def test_predict_trial_mismatch(synth_data): + """LOTO requires the same number of trials for predict as fit.""" + model = BandedTRF(tmin=synth_data['tmin'], tmax=synth_data['tmax'], sfreq=synth_data['sfreq']) + model.fit(data=synth_data['data'], feature_order=synth_data['feature_order'], target='resp') + + # Try predicting with only 2 trials instead of 3 + short_data = synth_data['data'][:2] + with pytest.raises(ValueError, match="LOTO predict requires the same number of trials"): + model.predict(short_data) \ No newline at end of file diff --git a/tests/stats/test_pairwise_correlation.py b/tests/stats/test_pairwise_correlation.py new file mode 100644 index 00000000..05b99fef --- /dev/null +++ b/tests/stats/test_pairwise_correlation.py @@ -0,0 +1,69 @@ +import numpy as np +import pytest +from naplib.stats import pairwise_correlation + +def test_pairwise_correlation_1d(): + # Identical 1D signals + a = np.array([1.0, 2.0, 3.0, 4.0]) + b = np.array([1.0, 2.0, 3.0, 4.0]) + corr = pairwise_correlation(a, b) + assert np.isclose(corr, 1.0, atol=1e-8) + + # Inverse 1D signals + b_inv = -b + corr_inv = pairwise_correlation(a, b_inv) + assert np.isclose(corr_inv, -1.0, atol=1e-8) + +def test_pairwise_correlation_2d_default_axis(): + # Default axis=0 (correlate columns over rows) + A = np.array([[1, 2], + [2, 1], + [3, 3]]) + B = A.copy() + corr = pairwise_correlation(A, B) # Expected shape (2,) + assert corr.shape == (2,) + assert np.allclose(corr, [1.0, 1.0], atol=1e-8) + +def test_pairwise_correlation_2d_custom_axis(): + # axis=1 (correlate rows over columns) + A = np.array([[1, 2, 3], + [4, 5, 6]]) + B = A.copy() + corr = pairwise_correlation(A, B, axis=1) # Expected shape (2,) + assert corr.shape == (2,) + assert np.allclose(corr, [1.0, 1.0], atol=1e-8) + +def test_pairwise_correlation_3d_neural_data(): + # Simulation of (trials, channels, time) correlating over time axis + rng = np.random.default_rng(42) + A = rng.standard_normal((5, 10, 100)) # 5 trials, 10 channels, 100 samples + B = A.copy() + + # Correlate over time (axis=2) + corr = pairwise_correlation(A, B, axis=2) + assert corr.shape == (5, 10) # One correlation per trial/channel + assert np.allclose(corr, 1.0, atol=1e-8) + +def test_pairwise_correlation_shape_mismatch(): + A = np.ones((10, 2)) + B = np.ones((10, 3)) + with pytest.raises(ValueError, match="A and B must have the same shape"): + pairwise_correlation(A, B) + +def test_pairwise_correlation_zero_variance(): + # Test epsilon handling for constant signals (prevents nan) + A = np.array([1.0, 1.0, 1.0]) + B = np.array([2.0, 2.0, 2.0]) + corr = pairwise_correlation(A, B) + # With 1e-15 in denominator and 0 in numerator, result is 0 + assert np.isclose(corr, 0.0, atol=1e-8) + +def test_pairwise_correlation_random_precision(): + # Test against np.corrcoef for a single pair to ensure mathematical parity + rng = np.random.default_rng(1) + a = rng.standard_normal(100) + b = rng.standard_normal(100) + + expected = np.corrcoef(a, b)[0, 1] + actual = pairwise_correlation(a, b) + assert np.isclose(actual, expected, atol=1e-8) \ No newline at end of file diff --git a/tests/test_brain_object.py b/tests/test_brain_object.py index dfdd085c..f6e08764 100644 --- a/tests/test_brain_object.py +++ b/tests/test_brain_object.py @@ -207,7 +207,8 @@ def test_plotly_electrode_coloring(data): def test_plotly_electrode_coloring_by_value(data): colors = ['k' if isL else 'r' for isL in data['isleft']] - fig, axes = plot_brain_elecs(data['brain_inflated'], data['coords'], data['isleft'], values=data['isleft'], vmin=-1, vmax=2, cmap='binary', hemi='both', view='medial', backend='plotly') + fig, axes = plot_brain_elecs(data['brain_inflated'], data['coords'], data['isleft'], values=data['isleft'], + vmin=-1, vmax=2, cmap='binary', hemi='both', view='medial', backend='plotly') assert len(fig.data) == 4 assert fig.data[0]['x'].shape == (163842,) assert fig.data[0]['facecolor'].shape == (327680, 4) @@ -249,9 +250,87 @@ def test_set_visible(data): ending_visible = brain_pial1.lh.alpha assert (ending_visible.sum() == 327680) - - - - - +def test_interpolate_electrodes_onto_brain(data): + """ + Test interpolation of electrode values onto the cortical surface. + """ + brain = data['brain_pial'] # Brain object containing lh and rh + lh = brain.lh + lh.reset_overlay() + + # 1. Setup: Place one active electrode near a known vertex + # Electrode 0 is at [-47.28, 16.29, -15.82] + coords = data['coords'][:1] + values = np.array([10.0]) + + # 2. Run interpolation + # k=1 (nearest neighbor), max_dist=10mm + lh.interpolate_electrodes_onto_brain(coords, values, k=1, max_dist=10, roi='all') + + # Check that vertices near the electrode have the value, and others are 0/nan + # Note: the code sets self.overlay[updated_vertices] = smoothed_values + assert np.isclose(lh.overlay.max(), 10.0) + assert np.any(lh.overlay == 10.0) + + # 3. Test ROI filtering + lh.reset_overlay() + # Use a ROI that doesn't exist near the electrode + lh.interpolate_electrodes_onto_brain(coords, values, k=1, max_dist=10, roi=['G_front_middle']) + + # If electrode 0 is in STG and we only allow Middle Frontal Gyrus, overlay should stay 0 + # (Assuming electrode 0 isn't in G_front_middle) + if 'G_front_middle' not in lh.num2label[lh.labels[0]]: + assert np.all(lh.overlay == 0) + +def test_interpolation_inverse_distance_weighting(data): + """ + Test that the weighting logic correctly averages two electrodes. + """ + lh = data['brain_pial'].lh + lh.reset_overlay() + + # Two electrodes: one with 10.0, one with 0.0, same distance from a vertex + # We'll mock this by providing coordinates equidistant to a specific point + target_vertex_coord = lh.coords[500] + offset = np.array([2, 0, 0]) + coords = np.array([target_vertex_coord + offset, target_vertex_coord - offset]) + values = np.array([10.0, 0.0]) + + lh.interpolate_electrodes_onto_brain(coords, values, k=2, max_dist=10) + + # At the midpoint (the vertex), the value should be the mean (5.0) + # because weights are 1/dist and distances are equal. + assert np.isclose(lh.overlay[500], 5.0, atol=0.1) + +def test_parcellate_overlay(data): + """ + Test that parcellation correctly merges vertex values into parcel-wide values. + """ + lh = data['brain_pial'].lh + lh.reset_overlay() + + # 1. Manually "paint" some vertices in a specific parcel + target_label_num = 10 + target_label_name = lh.num2label[target_label_num] + mask = lh.labels == target_label_num + + # Set half the vertices in this parcel to 100, the other half to 0 + indices = np.where(mask)[0] + mid = len(indices) // 2 + lh.overlay[indices[:mid]] = 100.0 + lh.overlay[indices[mid:]] = 0.0 + + # 2. Run parcellation with mean + lh.parcellate_overlay(merge_func=np.mean) + + # 3. All vertices in that parcel should now be the mean (50.0) + assert np.allclose(lh.overlay[mask], 50.0) + + # 4. Check with a different merge function (max) + lh.reset_overlay() + lh.overlay[indices[:mid]] = 10.0 + lh.overlay[indices[mid:]] = 50.0 + lh.parcellate_overlay(merge_func=np.max) + + assert np.allclose(lh.overlay[mask], 50.0)