79 changes: 78 additions & 1 deletion src/modalsheaf/applications/neuro.py
@@ -567,13 +567,79 @@ def get_section_at_time(self, t: int) -> np.ndarray:
"""Get the signal vector across all regions at timepoint t."""
return self.time_series[t, :]

# ==================== Signal Validation ====================

def validate_signal_assumptions(self) -> bool:
"""
Check if signals are properly normalized for dissonance detection.

Since dissonance is computed as the Euclidean norm of the coboundary
(s_j - s_i), regions with vastly different variances will produce spurious
dissonance dominated by high-variance regions rather than true topological structure.

This method checks for:
1. Variance disparity > 10x between regions (suggests z-scoring is needed)
2. Mean disparity (signals should be centered)

Returns:
True if signals appear properly normalized, False otherwise

Warns:
UserWarning if high variance disparity or uncentered means are detected
"""
import warnings

variances = np.var(self.time_series, axis=0)
means = np.mean(self.time_series, axis=0)

mean_var = np.mean(variances)
is_valid = True

# Check variance disparity
if mean_var > 0:
high_var = np.any(variances > mean_var * 10)
low_var = np.any(variances < mean_var * 0.1)

if high_var or low_var:
high_var_regions = [
self.regions[i].label
for i in np.where(variances > mean_var * 10)[0]
][:3]
low_var_regions = [
self.regions[i].label
for i in np.where(variances < mean_var * 0.1)[0]
][:3]

warnings.warn(
f"High variance disparity between brain regions detected. "
f"High-variance regions: {high_var_regions}, "
f"Low-variance regions: {low_var_regions}. "
f"Dissonance metric may be dominated by high-variance regions. "
f"Ensure data is z-scored (standardize=True in load_fmri_data).",
UserWarning
)
is_valid = False

# Check if means are approximately centered
mean_std = np.std(means)
if mean_std > 1.0:
warnings.warn(
f"Signal means vary significantly across regions (std={mean_std:.2f}). "
f"Consider centering each region's time series for fair comparison.",
UserWarning
)
is_valid = False

return is_valid
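# Illustrative sketch (standalone, synthetic data; not part of the class):
# per-region z-scoring removes the variance disparity this method warns about.
rng = np.random.default_rng(0)
ts = rng.normal(size=(200, 4)) * np.array([1.0, 50.0, 1.0, 1.0])  # region 1 has 50x the std
ts_z = (ts - ts.mean(axis=0)) / ts.std(axis=0)                    # per-region z-score
assert np.allclose(ts_z.std(axis=0), 1.0)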

# ==================== Coboundary / Dissonance Detection ====================

def detect_dissonance(
self,
t: int,
normalize: bool = True,
top_k: int = 10,
validate_first: bool = True
) -> DissonanceResult:
"""
Detect cognitive dissonance at a specific timepoint.
@@ -591,16 +657,27 @@ def detect_dissonance(
t: Timepoint to analyze
normalize: Whether to normalize by edge count (default: True)
top_k: Number of top dissonant edges to report (default: 10)
validate_first: If True, validate signal assumptions when t == 0 (default: True)

Returns:
DissonanceResult with metric and edge-wise analysis

Note:
The coboundary calculation (s_j - s_i) assumes signals are z-scored.
If regions have vastly different variances, high-variance regions
will dominate the dissonance metric. Use validate_signal_assumptions()
to check this, or pass validate_first=True (default).
"""
if self._graph is None:
raise RuntimeError("Call build_complex() first")

if t < 0 or t >= self.n_timepoints:
raise ValueError(f"Timepoint {t} out of range [0, {self.n_timepoints})")

# Validate signal assumptions when analyzing the first timepoint (t = 0)
if validate_first and t == 0:
self.validate_signal_assumptions()

# Get signal at timepoint t
signal = self.get_section_at_time(t)

87 changes: 69 additions & 18 deletions src/modalsheaf/consistency.py
@@ -319,21 +319,29 @@ def compute_sheaf_laplacian(
embeddings: Dict[str, np.ndarray]
) -> np.ndarray:
"""
Compute the sheaf Laplacian matrix L = δᵀδ.

The sheaf Laplacian generalizes the graph Laplacian to account
for the restriction maps (transformations) between modalities.
Unlike a standard graph Laplacian that uses -I for off-diagonal blocks,
this implementation uses the actual restriction map matrices -R_{ij}
when available, enabling detection of "twist" inconsistencies
(e.g., rotations, calibration errors).

Mathematical Background:

For a cellular sheaf on a graph G = (V, E):

1. The coboundary operator δ: C⁰ → C¹ measures disagreement:
(δx)_e = R_{e←v}(x_v) - R_{e←u}(x_u)

2. The Laplacian is L = δᵀδ (or δδᵀ for the dual)

3. For edge (i,j) with restriction map R_{ij}:
- Off-diagonal block: L_{ij} = -R_{ij}
- Diagonal block: L_{ii} = Σ_k R_{ik}ᵀ R_{ik}

4. Key properties:
- L is positive semi-definite
- ker(L) = H⁰ (global sections / consensus states)
- xᵀLx = total squared disagreement
@@ -349,6 +357,10 @@ def compute_sheaf_laplacian(
Example: 3 temperature sensors in a triangle
- ker(L) = span{[1,1,1]} = "all same temperature"
- Other eigenvectors = patterns of disagreement

Example: Sensors with rotation calibration
- If sensor B reads rotated values relative to A, R_{AB} ≠ I
- The Laplacian captures this calibration mismatch

Args:
graph: The modality graph (defines the cellular sheaf structure)
@@ -361,9 +373,11 @@
>>> graph = ModalityGraph()
>>> graph.add_modality("A")
>>> graph.add_modality("B")
>>> # Rotation by 90 degrees
>>> R = np.array([[0, -1], [1, 0]])
>>> graph.add_transformation("A", "B", forward=lambda x: R @ x, matrix=R)
>>>
>>> embeddings = {"A": np.array([1.0, 0.0]), "B": np.array([0.0, 1.0])}
>>> L = compute_sheaf_laplacian(graph, embeddings)
>>>
>>> # Kernel dimension = H⁰ dimension
@@ -379,26 +393,63 @@
# Build block Laplacian
L = np.zeros((n * d, n * d))

# First pass: compute off-diagonal blocks and accumulate diagonal contributions
for i, mod_i in enumerate(modalities):
degree_block = np.zeros((d, d))

for j, mod_j in enumerate(modalities):
if i == j:
continue

transform = graph.get_transformation(mod_i, mod_j)
if transform is not None:
# Get restriction map matrix R
R = _get_restriction_matrix(transform, d)

# Off-diagonal block: -R
L[i*d:(i+1)*d, j*d:(j+1)*d] = -R

# Accumulate diagonal contribution: R^T @ R
degree_block += R.T @ R

# Diagonal block: sum of R^T @ R for all incident edges
L[i*d:(i+1)*d, i*d:(i+1)*d] = degree_block

return L
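# Numeric sanity sketch (standalone; assumes a single edge A-B whose linear
# map is stored in both directions as R and R.T, per the block rules above):
_R = np.array([[0.0, -1.0], [1.0, 0.0]])          # 90-degree rotation
_L = np.block([[_R.T @ _R, -_R],                  # L_AA = R^T R, L_AB = -R
               [-_R.T, _R.T @ _R]])               # L_BA = -R^T,  L_BB = R^T R
assert np.all(np.linalg.eigvalsh(_L) > -1e-12)    # positive semi-definite
_b = np.array([1.0, 2.0])
_x = np.concatenate([_R @ _b, _b])                # configuration with x_A = R @ x_B
assert np.allclose(_L @ _x, 0.0)                  # lies in ker(L) = H^0 (global section)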


def _get_restriction_matrix(transform: Transformation, d: int) -> np.ndarray:
"""
Extract the restriction map matrix from a Transformation.

Priority:
1. Use transform.matrix if explicitly provided
2. Fall back to the identity matrix (standard graph Laplacian behavior)

Args:
transform: The Transformation object
d: Expected dimension of the matrix

Returns:
d×d restriction map matrix
"""
# Check if transform has an explicit matrix representation
if hasattr(transform, 'matrix') and transform.matrix is not None:
R = np.asarray(transform.matrix)
# Validate shape
if R.shape == (d, d):
return R
# On shape mismatch, warn and fall back to the identity
import warnings
warnings.warn(
f"Transformation '{transform.name}' has matrix of shape {R.shape}, "
f"expected ({d}, {d}). Falling back to identity matrix."
)

# Fallback: identity matrix (reduces to standard graph Laplacian)
return np.eye(d)


def diffuse_to_consensus(
graph: ModalityGraph,
embeddings: Dict[str, np.ndarray],
1 change: 1 addition & 0 deletions src/modalsheaf/core.py
@@ -126,6 +126,7 @@ class Transformation:
info_loss_estimate: float = 0.5 # 0.0 = no loss, 1.0 = total loss
name: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
matrix: Optional[np.ndarray] = None # Linear restriction map matrix R for sheaf Laplacian

def __post_init__(self):
if self.name is None:
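# Usage sketch for the new `matrix` field (mirrors the docstring example in
# consistency.py; assumes a ModalityGraph `graph` with modalities "A" and "B"):
R = np.array([[0.0, -1.0], [1.0, 0.0]])           # linear restriction map (rotation)
graph.add_transformation("A", "B", forward=lambda x: R @ x, matrix=R)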
40 changes: 24 additions & 16 deletions src/modalsheaf/modalities/transforms.py
@@ -29,9 +29,9 @@ def text_to_tokens(
max_length: Optional[int] = None,
) -> np.ndarray:
"""
Restriction map: text (Global) → tokens (Local)

Extracts local constituent parts (tokens) from the global object (text).
Uses simple whitespace tokenization if no tokenizer provided.
"""
if tokenizer is not None:
@@ -54,9 +54,9 @@ def tokens_to_text(
tokenizer: Optional[Any] = None,
) -> str:
"""
Gluing/Extension map: tokens (Local) → text (Global)

Assembles local parts (tokens) into a global structure (text).
"""
tokens = np.asarray(tokens).flatten().tolist()

@@ -70,9 +70,9 @@

def text_to_sentences(text: str) -> List[str]:
"""
Restriction map: text (Global) → sentences (Local)

Restricts the global text to local open sets (sentences).
"""
import re
# Simple sentence splitting
@@ -82,37 +82,45 @@ def text_to_sentences(text: str) -> List[str]:

def sentences_to_text(sentences: List[str]) -> str:
"""
Gluing/Extension map: sentences (Local) → text (Global)

Assembles local parts (sentences) into a global structure.
"""
return ' '.join(sentences)


def text_to_words(text: str) -> List[str]:
"""
Restriction map: text (Global) → words (Local)

Restricts the global text to local constituents (words).
"""
return text.split()


def words_to_text(words: List[str]) -> str:
"""
Gluing/Extension map: words (Local) → text (Global)

Assembles local parts (words) into a global structure.
"""
return ' '.join(words)


def text_to_chars(text: str) -> List[str]:
"""
Restriction map: text (Global) → characters (Local)

Restricts the global text to its finest local constituents.
"""
return list(text)


def chars_to_text(chars: List[str]) -> str:
"""
Gluing/Extension map: characters (Local) → text (Global)

Assembles local parts (characters) into a global structure.
"""
return ''.join(chars)
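# Round-trip sketch (uses the functions above): restriction followed by
# gluing recovers the global object when the local decomposition is lossless.
_text = "sheaves glue local data"
assert chars_to_text(text_to_chars(_text)) == _text    # characters: exact
assert words_to_text(text_to_words(_text)) == _text    # words: exact for single-spaced text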

@@ -124,9 +132,9 @@ def image_to_patches(
patch_size: int = 16,
) -> np.ndarray:
"""
Restriction map: image (Global) → patches (Local)

Restricts the global image to local open sets (patches).
Returns shape (num_patches, patch_size, patch_size, channels)
"""
image = np.asarray(image)
@@ -158,9 +166,9 @@ def patches_to_image(
image_shape: Tuple[int, int],
) -> np.ndarray:
"""
Gluing/Extension map: patches (Local) → image (Global)

Assembles local parts (patches) into a global structure (image).
"""
patches = np.asarray(patches)
num_patches, patch_size, _, C = patches.shape
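# Usage sketch (hypothetical shapes; relies on the full implementations above):
_img = np.zeros((32, 32, 3))
_patches = image_to_patches(_img, patch_size=16)        # expected shape (4, 16, 16, 3)
_restored = patches_to_image(_patches, image_shape=(32, 32))
assert _restored.shape == _img.shape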