From ee73f3b51b9e621721c8867ce0ef543f9443c39e Mon Sep 17 00:00:00 2001 From: Mike Date: Wed, 7 Jan 2026 13:42:28 -0500 Subject: [PATCH] Fix sheaf-theoretic rigor: Laplacian restriction maps, terminology alignment, and neuro validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add matrix field to Transformation dataclass for explicit restriction map representation - Fix compute_sheaf_laplacian to use actual restriction maps R instead of identity matrices * Off-diagonal blocks: L_ij = -R_ij (not -I) * Diagonal blocks: L_ii = Σ_k R_ik^T @ R_ik * Enables detection of twist inconsistencies (rotations, calibrations) - Correct docstring terminology in transforms.py * text_to_tokens, image_to_patches: Restriction maps (Global → Local) * tokens_to_text, patches_to_image: Gluing/Extension maps (Local → Global) - Add validate_signal_assumptions() to BrainSheaf * Checks variance disparity (>10x warns about z-scoring needed) * Checks mean centering across regions * Auto-validates on detect_dissonance calls at t=0 (when validate_first=True) * Prevents false dissonance from high-variance regions Addresses conceptual gaps identified in math rigor review. 
--- src/modalsheaf/applications/neuro.py | 79 +++++++++++++++++++++- src/modalsheaf/consistency.py | 87 ++++++++++++++++++++----- src/modalsheaf/core.py | 1 + src/modalsheaf/modalities/transforms.py | 40 +++++++----- 4 files changed, 172 insertions(+), 35 deletions(-) diff --git a/src/modalsheaf/applications/neuro.py b/src/modalsheaf/applications/neuro.py index 1b2f7c3..22f0ab3 100644 --- a/src/modalsheaf/applications/neuro.py +++ b/src/modalsheaf/applications/neuro.py @@ -567,13 +567,79 @@ def get_section_at_time(self, t: int) -> np.ndarray: """Get the signal vector across all regions at timepoint t.""" return self.time_series[t, :] + # ==================== Signal Validation ==================== + + def validate_signal_assumptions(self) -> bool: + """ + Check if signals are properly normalized for dissonance detection. + + Since dissonance is calculated as Euclidean distance (s_j - s_i), + regions with vastly different variances will generate false dissonance + dominated by high-variance regions rather than true topological structure. + + This method checks for: + 1. Variance disparity > 10x between regions (suggests z-scoring needed) + 2. Mean disparity (signals should be centered) + + Returns: + True if signals appear properly normalized, False otherwise + + Raises: + UserWarning if high variance disparity is detected + """ + import warnings + + variances = np.var(self.time_series, axis=0) + means = np.mean(self.time_series, axis=0) + + mean_var = np.mean(variances) + is_valid = True + + # Check variance disparity + if mean_var > 0: + high_var = np.any(variances > mean_var * 10) + low_var = np.any(variances < mean_var * 0.1) + + if high_var or low_var: + high_var_regions = [ + self.regions[i].label + for i in np.where(variances > mean_var * 10)[0] + ][:3] + low_var_regions = [ + self.regions[i].label + for i in np.where(variances < mean_var * 0.1)[0] + ][:3] + + warnings.warn( + f"High variance disparity between brain regions detected. 
" + f"High-variance regions: {high_var_regions}, " + f"Low-variance regions: {low_var_regions}. " + f"Dissonance metric may be dominated by high-variance regions. " + f"Ensure data is z-scored (standardize=True in load_fmri_data).", + UserWarning + ) + is_valid = False + + # Check if means are approximately centered + mean_std = np.std(means) + if mean_std > 1.0: + warnings.warn( + f"Signal means vary significantly across regions (std={mean_std:.2f}). " + f"Consider centering each region's time series for fair comparison.", + UserWarning + ) + is_valid = False + + return is_valid + # ==================== Coboundary / Dissonance Detection ==================== def detect_dissonance( self, t: int, normalize: bool = True, - top_k: int = 10 + top_k: int = 10, + validate_first: bool = True ) -> DissonanceResult: """ Detect cognitive dissonance at a specific timepoint. @@ -591,9 +657,16 @@ def detect_dissonance( t: Timepoint to analyze normalize: Whether to normalize by edge count (default: True) top_k: Number of top dissonant edges to report (default: 10) + validate_first: If True, validate signal assumptions on first call (default: True) Returns: DissonanceResult with metric and edge-wise analysis + + Note: + The coboundary calculation (s_j - s_i) assumes signals are z-scored. + If regions have vastly different variances, high-variance regions + will dominate the dissonance metric. Use validate_signal_assumptions() + to check this, or pass validate_first=True (default). 
""" if self._graph is None: raise RuntimeError("Call build_complex() first") @@ -601,6 +674,10 @@ def detect_dissonance( if t < 0 or t >= self.n_timepoints: raise ValueError(f"Timepoint {t} out of range [0, {self.n_timepoints})") + # Validate signal assumptions on first call (t=0 or first unique call) + if validate_first and t == 0: + self.validate_signal_assumptions() + # Get signal at timepoint t signal = self.get_section_at_time(t) diff --git a/src/modalsheaf/consistency.py b/src/modalsheaf/consistency.py index 43d5cf4..e6e992d 100644 --- a/src/modalsheaf/consistency.py +++ b/src/modalsheaf/consistency.py @@ -319,21 +319,29 @@ def compute_sheaf_laplacian( embeddings: Dict[str, np.ndarray] ) -> np.ndarray: """ - Compute the sheaf Laplacian matrix. + Compute the sheaf Laplacian matrix L = δᵀδ. The sheaf Laplacian generalizes the graph Laplacian to account for the restriction maps (transformations) between modalities. + Unlike a standard graph Laplacian that uses -I for off-diagonal blocks, + this implementation uses the actual restriction map matrices -R_{ij} + when available, enabling detection of "twist" inconsistencies + (e.g., rotations, calibration errors). Mathematical Background: For a cellular sheaf on a graph G = (V, E): 1. The coboundary operator δ: C⁰ → C¹ measures disagreement: - (δx)_e = F_{e←v}(x_v) - F_{e←u}(x_u) + (δx)_e = R_{e←v}(x_v) - R_{e←u}(x_u) 2. The Laplacian is L = δᵀδ (or δδᵀ for the dual) - 3. Key properties: + 3. For edge (i,j) with restriction map R_{ij}: + - Off-diagonal block: L_{ij} = -R_{ij} + - Diagonal block: L_{ii} = Σ_k R_{ik}ᵀ R_{ik} + + 4. 
Key properties: - L is positive semi-definite - ker(L) = H⁰ (global sections / consensus states) - xᵀLx = total squared disagreement @@ -349,6 +357,10 @@ def compute_sheaf_laplacian( Example: 3 temperature sensors in a triangle - ker(L) = span{[1,1,1]} = "all same temperature" - Other eigenvectors = patterns of disagreement + + Example: Sensors with rotation calibration + - If sensor B reads rotated values relative to A, R_{AB} ≠ I + - The Laplacian captures this calibration mismatch Args: graph: The modality graph (defines the cellular sheaf structure) @@ -361,9 +373,11 @@ def compute_sheaf_laplacian( >>> graph = ModalityGraph() >>> graph.add_modality("A") >>> graph.add_modality("B") - >>> graph.add_transformation("A", "B", forward=lambda x: x) + >>> # Rotation by 90 degrees + >>> R = np.array([[0, -1], [1, 0]]) + >>> graph.add_transformation("A", "B", forward=lambda x: R @ x, matrix=R) >>> - >>> embeddings = {"A": np.array([1.0]), "B": np.array([2.0])} + >>> embeddings = {"A": np.array([1.0, 0.0]), "B": np.array([0.0, 1.0])} >>> L = compute_sheaf_laplacian(graph, embeddings) >>> >>> # Kernel dimension = H⁰ dimension @@ -379,26 +393,63 @@ def compute_sheaf_laplacian( # Build block Laplacian L = np.zeros((n * d, n * d)) + # First pass: compute off-diagonal blocks and accumulate diagonal contributions for i, mod_i in enumerate(modalities): + degree_block = np.zeros((d, d)) + for j, mod_j in enumerate(modalities): if i == j: - # Diagonal: sum of incident edge contributions - degree = 0 - for k, mod_k in enumerate(modalities): - if graph.get_transformation(mod_i, mod_k) is not None: - degree += 1 - L[i*d:(i+1)*d, i*d:(i+1)*d] = degree * np.eye(d) - else: - # Off-diagonal: -R^T R for restriction map R - transform = graph.get_transformation(mod_i, mod_j) - if transform is not None: - # For linear transforms, we'd use the matrix - # For now, approximate with identity - L[i*d:(i+1)*d, j*d:(j+1)*d] = -np.eye(d) + continue + + transform = graph.get_transformation(mod_i, 
mod_j) + if transform is not None: + # Get restriction map matrix R + R = _get_restriction_matrix(transform, d) + + # Off-diagonal block: -R + L[i*d:(i+1)*d, j*d:(j+1)*d] = -R + + # Accumulate diagonal contribution: R^T @ R + degree_block += R.T @ R + + # Diagonal block: sum of R^T @ R for all incident edges + L[i*d:(i+1)*d, i*d:(i+1)*d] = degree_block return L +def _get_restriction_matrix(transform: Transformation, d: int) -> np.ndarray: + """ + Extract the restriction map matrix from a Transformation. + + Priority: + 1. Use transform.matrix if explicitly provided + 2. Fallback to identity matrix (standard graph Laplacian behavior) + + Args: + transform: The Transformation object + d: Expected dimension of the matrix + + Returns: + d×d restriction map matrix + """ + # Check if transform has an explicit matrix representation + if hasattr(transform, 'matrix') and transform.matrix is not None: + R = np.asarray(transform.matrix) + # Validate shape + if R.shape == (d, d): + return R + # If shape mismatch, warn and fallback + import warnings + warnings.warn( + f"Transformation '{transform.name}' has matrix of shape {R.shape}, " + f"expected ({d}, {d}). Falling back to identity matrix." 
+ ) + + # Fallback: identity matrix (reduces to standard graph Laplacian) + return np.eye(d) + + def diffuse_to_consensus( graph: ModalityGraph, embeddings: Dict[str, np.ndarray], diff --git a/src/modalsheaf/core.py b/src/modalsheaf/core.py index 42fa56c..6fcda51 100644 --- a/src/modalsheaf/core.py +++ b/src/modalsheaf/core.py @@ -126,6 +126,7 @@ class Transformation: info_loss_estimate: float = 0.5 # 0.0 = no loss, 1.0 = total loss name: Optional[str] = None metadata: Dict[str, Any] = field(default_factory=dict) + matrix: Optional[np.ndarray] = None # Linear restriction map matrix R for sheaf Laplacian def __post_init__(self): if self.name is None: diff --git a/src/modalsheaf/modalities/transforms.py b/src/modalsheaf/modalities/transforms.py index 49c8cd4..e285d43 100644 --- a/src/modalsheaf/modalities/transforms.py +++ b/src/modalsheaf/modalities/transforms.py @@ -29,9 +29,9 @@ def text_to_tokens( max_length: Optional[int] = None, ) -> np.ndarray: """ - Extension map: text → tokens + Restriction map: text (Global) → tokens (Local) - Converts text string to token IDs. + Extracts local constituent parts (tokens) from the global object (text). Uses simple whitespace tokenization if no tokenizer provided. """ if tokenizer is not None: @@ -54,9 +54,9 @@ def tokens_to_text( tokenizer: Optional[Any] = None, ) -> str: """ - Restriction map: tokens → text + Gluing/Extension map: tokens (Local) → text (Global) - Converts token IDs back to text string. + Assembles local parts (tokens) into a global structure (text). """ tokens = np.asarray(tokens).flatten().tolist() @@ -70,9 +70,9 @@ def tokens_to_text( def text_to_sentences(text: str) -> List[str]: """ - Restriction map: text → sentences + Restriction map: text (Global) → sentences (Local) - Split text into sentences. + Restricts the global text to local open sets (sentences). 
""" import re # Simple sentence splitting @@ -82,37 +82,45 @@ def text_to_sentences(text: str) -> List[str]: def sentences_to_text(sentences: List[str]) -> str: """ - Extension map: sentences → text + Gluing/Extension map: sentences (Local) → text (Global) - Join sentences back into text. + Assembles local parts (sentences) into a global structure. """ return ' '.join(sentences) def text_to_words(text: str) -> List[str]: """ - Restriction map: text → words + Restriction map: text (Global) → words (Local) + + Restricts the global text to local constituents (words). """ return text.split() def words_to_text(words: List[str]) -> str: """ - Extension map: words → text + Gluing/Extension map: words (Local) → text (Global) + + Assembles local parts (words) into a global structure. """ return ' '.join(words) def text_to_chars(text: str) -> List[str]: """ - Restriction map: text → characters + Restriction map: text (Global) → characters (Local) + + Restricts the global text to its finest local constituents. """ return list(text) def chars_to_text(chars: List[str]) -> str: """ - Extension map: characters → text + Gluing/Extension map: characters (Local) → text (Global) + + Assembles local parts (characters) into a global structure. """ return ''.join(chars) @@ -124,9 +132,9 @@ def image_to_patches( patch_size: int = 16, ) -> np.ndarray: """ - Restriction map: image → patches + Restriction map: image (Global) → patches (Local) - Split image into non-overlapping patches. + Restricts the global image to local open sets (patches). Returns shape (num_patches, patch_size, patch_size, channels) """ image = np.asarray(image) @@ -158,9 +166,9 @@ def patches_to_image( image_shape: Tuple[int, int], ) -> np.ndarray: """ - Extension map: patches → image + Gluing/Extension map: patches (Local) → image (Global) - Reassemble patches into image. + Assembles local parts (patches) into a global structure (image). """ patches = np.asarray(patches) num_patches, patch_size, _, C = patches.shape