From fbc3c52e5cde204790b44ef392532cf50b1d94ef Mon Sep 17 00:00:00 2001 From: Hritik Arasu Date: Tue, 7 Oct 2025 13:33:15 -0700 Subject: [PATCH] Enhance reliability mask calculation by introducing new functions for train/test split analysis. - Added `compute_reliability_with_train_test_split()` to compute voxel reliability using only training set images, ensuring train/test purity and averaging across all available repeats. - Implemented `validate_reliability_calculation()` for comprehensive quality control checks on reliability scores. - Created `compare_reliability_methods()` to compare old and new reliability calculation methods. - Updated `main-multisession-3tasks.ipynb` with new cells for computing reliability scores, performing QC, and comparing methods. - Improved documentation in `RELIABILITY_MASK_IMPROVEMENTS.md` to detail changes and usage examples. --- RELIABILITY_MASK_IMPROVEMENTS.md | 246 +++++++++++++++++++++++++++++++ main-multisession-3tasks.ipynb | 4 +- utils.py | 238 ++++++++++++++++++++++++++++++ 3 files changed, 486 insertions(+), 2 deletions(-) create mode 100644 RELIABILITY_MASK_IMPROVEMENTS.md diff --git a/RELIABILITY_MASK_IMPROVEMENTS.md b/RELIABILITY_MASK_IMPROVEMENTS.md new file mode 100644 index 000000000..bdadd9300 --- /dev/null +++ b/RELIABILITY_MASK_IMPROVEMENTS.md @@ -0,0 +1,246 @@ +# Reliability Mask Calculation Improvements + +## Summary + +Improved the reliability mask calculation to address two critical issues with the original implementation: + +1. **All Repeats Averaging**: Now averages across all available repeat presentations (not just first 2) +2. **Train/Test Purity**: Only uses training set images for reliability calculation to prevent test set contamination + +## Changes Made + +### 1. 
New Function in `utils.py` + +Added `compute_reliability_with_train_test_split()` which: +- Filters pairs to only include those where ALL repeats are in the training set +- Computes pairwise correlations for all unique combinations of repeats +- Averages across all pairwise correlations (not just one pair) +- Returns reliability scores and the filtered pairs used + +```python +rels, train_pairs = utils.compute_reliability_with_train_test_split( + vox=vox, + pairs=pairs, + train_indices=train_image_indices, + test_indices=test_image_indices, + verbose=True +) +``` + +### 2. QC/Validation Functions + +Added two validation functions: + +**`validate_reliability_calculation()`**: Comprehensive QC checks including: +- Reliability score statistics (mean, std, range, percentiles) +- Distribution analysis (% positive voxels, NaN counts) +- Train/test purity verification (overlap checks) +- Repeat usage statistics +- Automatic warning detection + +**`compare_reliability_methods()`**: Side-by-side comparison of: +- Old method (first 2 repeats only, all pairs) +- New method (all repeats, train-only pairs) +- Correlation between methods +- Mean difference analysis + +### 3. Updated Notebook (`main-multisession-3tasks.ipynb`) + +Added new cells after the train/test split section: +- Markdown header explaining the improvements +- Cell to compute new reliability scores with QC +- Cell to compare old vs new methods with visualizations +- Cell to assign the new scores to `rels` for downstream use +- Updated old method header with deprecation notice + +## How It Works + +### Old Method Issues + +```python +# Only uses first 2 repeats +pairs_homog = np.array([[p[0], p[1]] for p in pairs]) + +# No train/test filtering - uses ALL pairs +vox_pairs = utils.zscore(vox[pairs_homog]) +rels = np.full(vox.shape[-1],np.nan) +for v in tqdm(range(vox.shape[-1])): + rels[v] = np.corrcoef(vox_pairs[:,0,v], vox_pairs[:,1,v])[1,0] +``` + +**Problems:** +1. 
If an image has 3+ repeats, only uses first 2 → wastes information
2. Includes test set images in reliability calculation → contamination
3. No verification of train/test purity

### New Method Solution

```python
# Filter to train-only pairs
train_pairs = []
for pair in pairs:
    if all(idx in train_set for idx in pair):
        train_pairs.append(pair)

# For each unique repeat combination (i, j), correlate the i-th and j-th
# repeat responses ACROSS train pairs. A single pair contributes only one
# value per repeat, so correlations cannot be computed within one pair.
max_repeats = max(len(pair) for pair in train_pairs)
combo_corrs = []
for i, j in itertools.combinations(range(max_repeats), 2):
    # Only pairs that actually have a (j+1)-th repeat can contribute
    eligible = [p for p in train_pairs if len(p) > j]
    if len(eligible) < 2:
        continue

    resp_i = vox[[p[i] for p in eligible]]  # (n_eligible, n_voxels)
    resp_j = vox[[p[j] for p in eligible]]

    # Z-score across pairs so correlations are comparable
    resp_i = (resp_i - resp_i.mean(0)) / (resp_i.std(0) + 1e-8)
    resp_j = (resp_j - resp_j.mean(0)) / (resp_j.std(0) + 1e-8)

    # Mean of the z-score products is the per-voxel Pearson correlation
    combo_corrs.append((resp_i * resp_j).mean(0))

# Average across all repeat combinations
rels = np.nanmean(np.stack(combo_corrs), axis=0)
```

**Benefits:**
1. Uses ALL available repeats → more robust estimates
2. Train/test pure → no contamination
3. Automatically validated with QC checks

## Usage Example

### Running the Improved Calculation

In the notebook, after defining train/test splits:

```python
# Compute reliability with train/test purity
rels_new, train_pairs_used = utils.compute_reliability_with_train_test_split(
    vox=vox,
    pairs=pairs,
    train_indices=train_image_indices,
    test_indices=test_image_indices,
    verbose=True
)

# Run QC validation
qc_results = utils.validate_reliability_calculation(
    rels=rels_new,
    vox=vox,
    pairs=train_pairs_used,
    train_indices=train_image_indices,
    test_indices=test_image_indices
)

# Compare with old method
comparison, rels_old, _ = utils.compare_reliability_methods(
    vox=vox,
    pairs=pairs,
    train_indices=train_image_indices,
    test_indices=test_image_indices
)

# Use the new reliability scores
rels = rels_new
```

### Expected QC Output

```
==================================================
RELIABILITY QC SUMMARY 
+================================================== +Reliability Statistics: + Mean ± Std: 0.XXXX ± 0.XXXX + Median: 0.XXXX + Range: [-X.XXXX, X.XXXX] + Q25-Q75: [0.XXXX, 0.XXXX] + % Positive: XX.X% + NaN voxels: X + +Repeat Usage: + Total pairs: XXX + Repeats per pair: 2-X (mean: X.X) + +Train/Test Purity: + Train/test overlap: 0 indices + Pure split: ✓ PASS + Test contamination in pairs: 0 indices + +Warnings: + ✓ All checks passed +================================================== +``` + +## Next Steps for Testing + +To validate on sub-005 ses-03: + +1. **Set parameters in notebook:** + ```python + sub = 'sub-005' + session = 'ses-03' + train_test_split = 'MST' # or whatever split you're using + ``` + +2. **Run through the notebook** up to the new reliability calculation cells + +3. **Check QC output:** + - Verify reliability mean is positive and reasonable (typically 0.1-0.5) + - Confirm train/test purity passes (0 overlap) + - Check that multiple repeats are being used (mean > 2.0) + - Compare correlation between old and new methods (should be high, > 0.8) + +4. **Visual inspection:** + - Histogram should show reasonable distribution + - Scatter plot should show strong correlation with old method + - Difference plot should be centered near 0 + +5. **Launch training** with the new reliability mask + +## Implementation Details + +### Handling Edge Cases + +- **No train pairs**: Returns NaN array if no pairs have all repeats in training set +- **Mixed repeat counts**: Handles pairs with different numbers of repeats (2, 3, 4, etc.) 
- **NaN handling**: Skips NaN correlations when averaging

### Performance Considerations

- Uses tqdm for progress tracking (can disable with `verbose=False`)
- Computes correlations per repeat combination rather than storing all responses
- Memory efficient for large voxel counts

### Z-scoring Strategy

Z-scoring is done **across pairs** (separately for each repeat position) before computing correlations:
- Normalizes for differences in overall activation levels
- Makes correlations comparable across repeat combinations
- Consistent with the original implementation, which also z-scored across pairs

Note that z-scoring *within* a 2-repeat pair is degenerate (the two z-scored values are always ±1), so normalization must happen across pairs.

## Testing Checklist

Before launching full training:

- [ ] Function returns expected shape
- [ ] No NaN values in reliability scores (unless expected)
- [ ] Train/test purity check passes (0 overlap)
- [ ] Mean reliability is positive and reasonable
- [ ] High correlation with old method (> 0.8)
- [ ] Visualizations look reasonable
- [ ] All available repeats are being used
- [ ] QC warnings pass

## Future Improvements

For handling multiple sessions with inhomogeneous repeats:

1. **Session-specific reliability**: Compute reliability per session, then aggregate
2. **Weighted averaging**: Weight by number of repeats available
3. **Cross-session validation**: Use one session's reliability to validate another
4. 
**Adaptive thresholding**: Different thresholds per session based on repeat counts + +## References + +- Original `compute_avg_repeat_corrs()` function (line 826 in utils.py) +- Notebook section: "Reliability calculation" (cells #VSC-bfa240ac onwards) +- Train/test split logic (cell #VSC-99479ec0) diff --git a/main-multisession-3tasks.ipynb b/main-multisession-3tasks.ipynb index eeadbc005..f2c814700 100644 --- a/main-multisession-3tasks.ipynb +++ b/main-multisession-3tasks.ipynb @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:be031750ac90b0c1dfd9030954aa5d8fb091851051ffc2446186ad6d2f40250d -size 1055888 +oid sha256:926706dd68eb5a699e0f0f8368f26cafc69a5e6512e9d3d4fe0e26a5a03d3488 +size 1059923 diff --git a/utils.py b/utils.py index 451c6b0c8..52076818f 100755 --- a/utils.py +++ b/utils.py @@ -849,6 +849,244 @@ def compute_avg_repeat_corrs(vox_repeats: np.ndarray) -> np.ndarray: return rels +def compute_reliability_with_train_test_split(vox, pairs, train_indices, test_indices=None, verbose=True): + """ + Compute voxel reliability using only training set images to maintain train/test purity. + Averages correlation across all available repeats (not just first 2). 
+ + Args: + vox: array of shape (n_images, n_voxels) containing voxel responses + pairs: list of lists, where each sublist contains indices of repeated presentations of the same image + train_indices: array of indices corresponding to training set images + test_indices: optional array of test set indices (used for validation checks) + verbose: whether to print diagnostic information + + Returns: + rels: array of shape (n_voxels,) containing reliability scores for each voxel + pairs_used: list of pairs that were used for reliability calculation (only train set) + + Note: + - Only pairs where ALL repeats are in the training set are used for reliability calculation + - This ensures train/test purity: no test set images contaminate reliability estimates + - Averages across all available pairwise correlations between repeats + """ + import itertools + from tqdm import tqdm + + train_set = set(train_indices) + if test_indices is not None: + test_set = set(test_indices) + # Verify no overlap between train and test + assert len(train_set & test_set) == 0, "Train and test sets must not overlap!" 
+ 
+    # Filter pairs to only include those where ALL repeats are in training set
+    train_pairs = []
+    for pair in pairs:
+        if all(idx in train_set for idx in pair):
+            train_pairs.append(pair)
+ 
+    if verbose:
+        print(f"Total pairs: {len(pairs)}")
+        print(f"Pairs with all repeats in training set: {len(train_pairs)}")
+        if len(train_pairs) == 0:
+            print("WARNING: No pairs found with all repeats in training set!")
+            print("This may indicate an issue with the train/test split or pairs structure.")
+ 
+    if len(train_pairs) == 0:
+        # Return NaN reliability if no valid pairs
+        return np.full(vox.shape[-1], np.nan), []
+ 
+    # For each unique repeat combination (i, j), correlate the i-th and j-th
+    # repeat responses across train pairs, per voxel, then average across
+    # combinations. A single pair contributes only one value per repeat, so
+    # correlations cannot be computed within one pair.
+    n_voxels = vox.shape[-1]
+    max_repeats = max(len(pair) for pair in train_pairs)
+    combos = list(itertools.combinations(range(max_repeats), 2))
+    combo_corrs = []
+ 
+    for i, j in tqdm(combos, disable=not verbose, desc="Computing reliability"):
+        # Only pairs that actually have a (j+1)-th repeat can contribute
+        eligible = [p for p in train_pairs if len(p) > j]
+        if len(eligible) < 2:
+            continue
+ 
+        resp_i = vox[[p[i] for p in eligible]]  # (n_eligible, n_voxels)
+        resp_j = vox[[p[j] for p in eligible]]
+ 
+        # Z-score across pairs so correlations are comparable
+        resp_i = (resp_i - resp_i.mean(axis=0)) / (resp_i.std(axis=0) + 1e-8)
+        resp_j = (resp_j - resp_j.mean(axis=0)) / (resp_j.std(axis=0) + 1e-8)
+ 
+        # Mean of the z-score products is the per-voxel Pearson correlation
+        combo_corrs.append((resp_i * resp_j).mean(axis=0))
+ 
+    if combo_corrs:
+        # Average across all repeat combinations, skipping NaN correlations
+        rels = np.nanmean(np.stack(combo_corrs), axis=0)
+    else:
+        rels = np.full(n_voxels, np.nan)
+ 
+    if verbose:
+        print(f"\nReliability statistics:")
+        print(f"  Mean: {np.nanmean(rels):.4f}")
+        print(f"  Std: {np.nanstd(rels):.4f}")
+        print(f"  Min: {np.nanmin(rels):.4f}")
+        print(f"  Max: {np.nanmax(rels):.4f}")
+        print(f"  NaN voxels: {np.sum(np.isnan(rels))} / {n_voxels}")
+ 
+    return rels, train_pairs
+ 
+ 
+def validate_reliability_calculation(rels, vox, pairs, train_indices, 
test_indices=None): + """ + Quality control checks for reliability calculation. + + Args: + rels: reliability scores for each voxel + vox: voxel data (n_images, n_voxels) + pairs: list of repeated image indices + train_indices: training set indices + test_indices: optional test set indices + + Returns: + dict with QC results and metrics + """ + qc_results = {} + + # Check 1: Reliability scores are in reasonable range + qc_results['mean_reliability'] = np.nanmean(rels) + qc_results['std_reliability'] = np.nanstd(rels) + qc_results['min_reliability'] = np.nanmin(rels) + qc_results['max_reliability'] = np.nanmax(rels) + qc_results['n_nan_voxels'] = np.sum(np.isnan(rels)) + qc_results['pct_positive'] = np.sum(rels > 0) / len(rels) * 100 + + # Check 2: Distribution analysis + qc_results['median_reliability'] = np.nanmedian(rels) + qc_results['q25_reliability'] = np.nanpercentile(rels, 25) + qc_results['q75_reliability'] = np.nanpercentile(rels, 75) + + # Check 3: Verify train/test purity + train_set = set(train_indices) + if test_indices is not None: + test_set = set(test_indices) + overlap = train_set & test_set + qc_results['train_test_overlap'] = len(overlap) + qc_results['train_test_pure'] = (len(overlap) == 0) + + # Check that pairs used only contain train indices + all_pair_indices = set() + for pair in pairs: + all_pair_indices.update(pair) + test_contamination = all_pair_indices & test_set + qc_results['test_contamination_in_pairs'] = len(test_contamination) + + # Check 4: Count repeats usage + repeat_counts = [len(pair) for pair in pairs] + qc_results['min_repeats'] = min(repeat_counts) if repeat_counts else 0 + qc_results['max_repeats'] = max(repeat_counts) if repeat_counts else 0 + qc_results['mean_repeats'] = np.mean(repeat_counts) if repeat_counts else 0 + qc_results['total_pairs_used'] = len(pairs) + + # Print summary + print("\n" + "="*50) + print("RELIABILITY QC SUMMARY") + print("="*50) + print(f"Reliability Statistics:") + print(f" Mean ± Std: 
{qc_results['mean_reliability']:.4f} ± {qc_results['std_reliability']:.4f}") + print(f" Median: {qc_results['median_reliability']:.4f}") + print(f" Range: [{qc_results['min_reliability']:.4f}, {qc_results['max_reliability']:.4f}]") + print(f" Q25-Q75: [{qc_results['q25_reliability']:.4f}, {qc_results['q75_reliability']:.4f}]") + print(f" % Positive: {qc_results['pct_positive']:.1f}%") + print(f" NaN voxels: {qc_results['n_nan_voxels']}") + + print(f"\nRepeat Usage:") + print(f" Total pairs: {qc_results['total_pairs_used']}") + print(f" Repeats per pair: {qc_results['min_repeats']}-{qc_results['max_repeats']} (mean: {qc_results['mean_repeats']:.1f})") + + if test_indices is not None: + print(f"\nTrain/Test Purity:") + print(f" Train/test overlap: {qc_results['train_test_overlap']} indices") + print(f" Pure split: {'✓ PASS' if qc_results['train_test_pure'] else '✗ FAIL'}") + print(f" Test contamination in pairs: {qc_results['test_contamination_in_pairs']} indices") + + # Warnings + print(f"\nWarnings:") + warnings = [] + if qc_results['mean_reliability'] < 0: + warnings.append("Mean reliability is negative!") + if qc_results['mean_reliability'] > 0.8: + warnings.append("Mean reliability is suspiciously high (>0.8)") + if qc_results['pct_positive'] < 50: + warnings.append("Less than 50% of voxels have positive reliability") + if test_indices is not None and not qc_results['train_test_pure']: + warnings.append("Train/test split is not pure!") + if qc_results['total_pairs_used'] == 0: + warnings.append("No pairs were used for reliability calculation!") + + if warnings: + for w in warnings: + print(f" ⚠ {w}") + else: + print(" ✓ All checks passed") + + print("="*50 + "\n") + + return qc_results + + +def compare_reliability_methods(vox, pairs, train_indices, test_indices=None): + """ + Compare old (first 2 repeats only) vs new (all repeats) reliability calculation. 
+ + Args: + vox: voxel data + pairs: list of repeated image indices + train_indices: training set indices + test_indices: optional test set indices + + Returns: + dict with comparison results + """ + print("Computing reliability using OLD method (first 2 repeats only)...") + # Old method: only use first 2 repeats, no train/test filtering + pairs_homog = np.array([[p[0], p[1]] for p in pairs]) + vox_pairs = zscore(vox[pairs_homog]) + rels_old = np.full(vox.shape[-1], np.nan) + for v in range(vox.shape[-1]): + rels_old[v] = np.corrcoef(vox_pairs[:,0,v], vox_pairs[:,1,v])[1,0] + + print("Computing reliability using NEW method (all repeats, train-only)...") + rels_new, train_pairs = compute_reliability_with_train_test_split( + vox, pairs, train_indices, test_indices, verbose=False + ) + + # Compare + comparison = {} + comparison['old_mean'] = np.nanmean(rels_old) + comparison['new_mean'] = np.nanmean(rels_new) + comparison['old_std'] = np.nanstd(rels_old) + comparison['new_std'] = np.nanstd(rels_new) + comparison['correlation'] = np.corrcoef(rels_old[~np.isnan(rels_old) & ~np.isnan(rels_new)], + rels_new[~np.isnan(rels_old) & ~np.isnan(rels_new)])[0,1] + comparison['mean_diff'] = np.nanmean(rels_new - rels_old) + + print("\n" + "="*50) + print("METHOD COMPARISON") + print("="*50) + print(f"Old method (first 2 repeats, all pairs):") + print(f" Mean: {comparison['old_mean']:.4f}") + print(f" Std: {comparison['old_std']:.4f}") + print(f"\nNew method (all repeats, train-only pairs):") + print(f" Mean: {comparison['new_mean']:.4f}") + print(f" Std: {comparison['new_std']:.4f}") + print(f"\nDifference:") + print(f" Mean difference: {comparison['mean_diff']:.4f}") + print(f" Correlation: {comparison['correlation']:.4f}") + print("="*50 + "\n") + + return comparison, rels_old, rels_new + + def get_pairs(data, repeat_indices=(0, 1)): """ Extract pairs based on specified repeat indices, falling back to available repeats.