diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..c25fab4 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,17 @@ +# Global owners (catch-all for all files) +* @BenjaminIsaac0111 + +# Core Model and Training Logic +src/spatial_transcript_former/models/ @BenjaminIsaac0111 +src/spatial_transcript_former/training/ @BenjaminIsaac0111 + +# Data Management and Scripts +src/spatial_transcript_former/data/ @BenjaminIsaac0111 +scripts/ @BenjaminIsaac0111 + +# Documentation +docs/ @BenjaminIsaac0111 +*.md @BenjaminIsaac0111 + +# GitHub Actions and Infrastructure +.github/ @BenjaminIsaac0111 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..bcebc17 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,69 @@ +# Contributing to SpatialTranscriptFormer + +Thank you for your interest in contributing! As a project at the intersection of deep learning and pathology, we value rigorous, well-tested contributions. + +## Project Status + +> [!IMPORTANT] +> This project is a **Work in Progress**. We are actively refining the core interaction logic and scaling behaviors. Expect breaking changes in the CLI and data schemas. + +## Intellectual Property & Licensing + +SpatialTranscriptFormer is protected under a **Proprietary Source Code License**. + +- **Academic/Non-Profit**: We encourage contributions from the research community. Contributions made under an academic affiliation are generally welcome. +- **Commercial/For-Profit**: Contributions from commercial entities or individuals intended for profit-seeking use require a separate agreement. +- **Assignment**: By submitting a Pull Request, you agree that your contributions will be licensed under the project's existing license, granting the author the right to include them in both the open-access and proprietary versions of the software. + +## Development Workflow + +### 1. 
Environment Setup + +Use the provided setup scripts to ensure a consistent development environment: + +```bash +# Windows +.\setup.ps1 + +# Linux/HPC +bash setup.sh +``` + +### 2. Coding Standards + +We use `black` for formatting and `flake8` for linting. Please ensure your code passes these checks before submitting. + +```bash +black . +flake8 src/ +``` + +### 3. Testing + +All new features must include unit tests in the `tests/` directory. We use `pytest` for our test suite. + +```bash +# Run all tests +.\test.ps1 # Windows +bash test.sh # Linux +``` + +## Pull Request Process + +1. **Open an Issue**: For major changes, please open an issue first to discuss the design. +2. **Branching**: Work on a descriptive feature branch (e.g., `feature/pathway-attention-mask`). +3. **Documentation**: Update relevant files in `docs/` and the `README.md` if your change affects usage. +4. **Verification**: Ensure all CI checks (GitHub Actions) pass. + +### Branch Protections + +To maintain code quality and stability, the following protections are enforced on the `main` branch: + +- **Require Pull Request Reviews**: All merges to `main` require at least one approval from a project maintainer. +- **Required Status Checks**: The `CI` workflow must pass successfully before a PR can be merged. This includes formatting checks (`black`) and the full test suite (`pytest`). +- **No Direct Pushes**: Pushing directly to `main` is disabled. All changes must go through the Pull Request process. +- **Linear History**: We prefer **Squash and Merge** to keep the `main` branch history clean and concise. + +## Contact + +For questions regarding commercial licensing or complex architectural changes, please contact the author directly. diff --git a/README.md b/README.md index c2eada9..f7884ff 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,16 @@ # SpatialTranscriptFormer -A transformer-based model for spatial transcriptomics. 
+> [!WARNING] +> **Work in Progress**: This project is under active development. Core architectures, CLI flags, and data formats are subject to major changes. + +A transformer-based model for spatial transcriptomics that bridges histology and biological pathways. + +## Key Features + +- **Quad-Flow Interaction**: Configurable attention between Pathways and Histology patches (`p2p`, `p2h`, `h2p`, `h2h`). +- **Pathway Bottleneck**: Interpretable gene expression prediction via 50 MSigDB Hallmark tokens. +- **Spatial Pattern Coherence**: Optimized using a composite **MSE + PCC (Pearson Correlation) loss** to prevent spatial collapse and ensure accurate morphology-expression mapping. +- **Biologically Informed Initialization**: Gene reconstruction weights derived from known hallmark memberships. ## License @@ -25,71 +35,66 @@ This project requires [Conda](https://docs.conda.io/en/latest/). ## Usage -After installation, the following command-line tools are available in your `SpatialTranscriptFormer` environment: - ### Download HEST Data Download specific subsets using filters or patterns: ```bash -# List available organs -stf-download --list_organs - # Download only the Bowel Cancer subset (including ST data and WSIs) stf-download --organ Bowel --disease Cancer --local_dir hest_data - -# Download any other organ -stf-download --organ Kidney ``` -### Split Dataset +### Train Models + +We provide presets for baseline models and scaled versions of the SpatialTranscriptFormer. 
-Perform patient-stratified splitting on the metadata: +```bash +# Recommended: Run the Interaction model with 4 transformer layers +python scripts/run_preset.py --preset stf_interaction_l4 -```powershell -stf-split HEST_v1_3_0.csv --val_ratio 0.2 +# Run the lightweight 2-layer version +python scripts/run_preset.py --preset stf_interaction_l2 + +# Run baselines +python scripts/run_preset.py --preset he2rna_baseline ``` -### Train Models +For a complete list of configurations, see the [Training Guide](docs/TRAINING_GUIDE.md). -Train baseline models (HE2RNA, ViT) or the proposed interaction model. For a complete list of configurations and examples, see the [Training Guide](docs/TRAINING_GUIDE.md). +### Real-Time Monitoring -```bash -# Option 1: Using the standard command -stf-train --data-dir A:\hest_data --model he2rna --epochs 20 +Monitor training progress, loss curves, and **prediction variance (collapse detector)** via the web dashboard: -# Option 2: Using the preset launcher (recommended for complex models) -python scripts/run_preset.py --preset stf_interaction --epochs 30 +```bash +python scripts/monitor.py --run-dir runs/stf_interaction_l4 ``` ### Inference & Visualization -Generate spatial maps comparing Ground Truth vs Predictions for specific samples: +Generate spatial maps comparing Ground Truth vs Predictions: ```bash -stf-predict --data-dir A:\hest_data --sample-id MEND29 --model-path checkpoints/best_model_he2rna.pth --model-type he2rna +stf-predict --data-dir A:\hest_data --sample-id MEND29 --model-path checkpoints/best_model.pth --model-type interaction ``` Visualization plots will be saved to the `./results` directory. ## Documentation -For detailed information on the data and code implementation, see: - +- [Models](docs/MODELS.md): Detailed model architectures and scaling parameters. - [Data Structure](docs/DATA_STRUCTURE.md): Organization of HEST data on disk. 
-- [Dataloader](docs/DATALOADER.md): Technical implementation of the PyTorch dataset and loaders. -- [Gene Analysis](docs/GENE_ANALYSIS.md): Analysis of available genes and modeling strategies. -- [Pathway Mapping](docs/PATHWAY_MAPPING.md): Strategies for clinical interpretability and pathway integration. -- [Latent Discovery](docs/LATENT_DISCOVERY.md): Unsupervised discovery of biological pathways from data. -- [Models](docs/MODELS.md): Model architectures and literature references. +- [Pathway Mapping](docs/PATHWAY_MAPPING.md): Clinical interpretability and pathway integration. +- [Gene Analysis](docs/GENE_ANALYSIS.md): Modeling strategies for high-dimensional gene space. ## Development ### Running Tests -Use the included test wrapper: - ```bash -# Run all tests +# Run all tests (Pytest wrapper) .\test.ps1 ``` + +## Contributing + +We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details on our coding standards and the process for submitting pull requests. Note that this project is under a proprietary license; contributions involve an assignment of rights for non-academic use. diff --git a/config.yaml b/config.yaml index 19384fd..ae788e3 100644 --- a/config.yaml +++ b/config.yaml @@ -4,15 +4,12 @@ # Data Paths # Candidates for the HEST data directory (checked in order) data_dirs: - - "hest_data" - - "../hest_data" - - "./data" - "A:\\hest_data" # Training Defaults training: num_genes: 1000 - batch_size: 32 + batch_size: 8 learning_rate: 0.0001 output_dir: "./checkpoints" diff --git a/docs/MODELS.md b/docs/MODELS.md index a7e7648..0b9dc54 100644 --- a/docs/MODELS.md +++ b/docs/MODELS.md @@ -1,194 +1,249 @@ -# Model Zoo and Literature References +# SpatialTranscriptFormer: Architecture & Design -This document provides a summary of the models implemented in this project, their origins in literature, and the architectural details of the **SpatialTranscriptFormer**. - -## 1. 
Regression Models (Patch-Level) - -### HE2RNA (ResNet-50) - -- **Reference**: [Schmauch et al. (2020). "A deep learning model to predict RNA-Seq expression of tumours from whole slide images." Nature Communications.](https://www.nature.com/articles/s41467-020-17679-z) -- **Description**: Uses a ResNet-50 backbone to extract features from histology patches and directly regresses a high-dimensional gene expression vector. -- **Equation**: - $$\hat{y} = \text{FC}(\text{ResNet50}(x))$$ - Where $x$ is the histology patch and $\hat{y} \in \mathbb{R}^G$ is the predicted expression for $G$ genes. - -### ViT-ST - -- **Description**: An adaptation of the Vision Transformer (ViT) to the spatial transcriptomics task by replacing the classification head with a regression head. -- **Equation**: - $$\hat{y} = \text{MLP}(\text{TransformerEncoder}(x_{tokens}))$$ +This document describes the architecture, design philosophy, and training objectives of the **SpatialTranscriptFormer** model, along with the baseline models implemented for comparison. --- -## 2. Multiple Instance Learning (Slide-Level) - -### Attention-MIL - -- **Reference**: [Ilse et al. (2018). "Attention-based Deep Multiple Instance Learning." ICML.](https://arxiv.org/abs/1802.04712) -- **Description**: Learns attention weights for individual patches to aggregate them into a slide-level representation. -- **Aggregation**: - $$z = \sum_{i=1}^N a_i h_i, \quad a_i = \frac{\exp(\mathbf{w}^\top \tanh(\mathbf{V} h_i^\top) \odot \text{sigm}(\mathbf{U} h_i^\top))}{\sum_{j=1}^N \exp(\mathbf{w}^\top \tanh(\mathbf{V} h_j^\top) \odot \text{sigm}(\mathbf{U} h_j^\top))}$$ - -### TransMIL - -- **Reference**: [Shao et al. (2021). "TransMIL: Transformer based Correlated Multiple Instance Learning for Whole Slide Image Classification." NeurIPS.](https://arxiv.org/abs/2106.00908) -- **Description**: Uses a Nyström-based linear transformer to capture correlations between patches across the entire slide. 
- -### Weak Supervision for MIL - -MIL models are trained with **bag-level supervision**: the model predicts a single slide-level expression vector and is supervised against the mean expression across all valid spots. - -Given a slide with $N$ valid spots and gene expression matrix $\mathbf{Y} \in \mathbb{R}^{N \times G}$, the **bag-level target** is: -$$\bar{y}_g = \frac{1}{N} \sum_{i=1}^N y_{ig}$$ - -For padded batches with mask $\mathbf{m} \in \{0, 1\}^{N}$, this becomes: -$$\bar{y}_g = \frac{\sum_{i=1}^N (1 - m_i) \cdot y_{ig}}{\sum_{i=1}^N (1 - m_i)}$$ - -The training loss is then: -$$\mathcal{L}_{weak} = \frac{1}{G} \sum_{g=1}^G (\hat{y}_g - \bar{y}_g)^2$$ +## 1. Problem Statement -### Spatial Attention Correlation +**Goal**: Predict spatially-resolved gene expression from histology images. -To evaluate whether MIL attention maps recover spatial localisation of gene activity, we compute the **Pearson correlation** between attention weights and total gene expression at each spot: -$$\rho = \text{corr}\left(\mathbf{a},\ \sum_{g=1}^G \mathbf{y}_g\right)$$ +**Data**: Spatial transcriptomics datasets (a subset of HEST-1k, filtered to bowel cancer from human patients) where each tissue section has: -where $\mathbf{a} \in \mathbb{R}^N$ is the vector of attention weights and $\sum_g \mathbf{y}_g \in \mathbb{R}^N$ is the total expression per spot. A high $\rho$ indicates the model has learned to attend to transcriptionally active regions. 
+- A whole-slide histology image (H&E) +- Per-spot gene expression counts with spatial coordinates -### Dense Supervision (Masked MSE) - -For models that support per-spot dense prediction in whole-slide mode (e.g., SpatialTranscriptFormer via `forward_dense`), we use a **masked MSE** that ignores padded positions: -$$\mathcal{L}_{dense} = \frac{\sum_{i=1}^N \sum_{g=1}^G (1 - m_i)(\hat{y}_{ig} - y_{ig})^2}{\sum_{i=1}^N (1 - m_i) \cdot G}$$ +**Challenge**: Directly predicting ~1000 genes from image patches is high-dimensional, noisy, and biologically uninterpretable. We need a structured bottleneck that compresses the gene space into biologically meaningful abstractions. --- -## 3. SpatialTranscriptFormer (Proposed Model) - -The **SpatialTranscriptFormer** (formerly Pathway Interaction Model) introduces a biologically-informed bottleneck layer using Cross-Attention between Learned Pathways and Image Features. - -### 3.1 Architecture Equations - -#### 1. Image Encoding - -Given an image patch $x$, we extract features $F$ using a backbone (e.g., ResNet or ViT): -$$F = \text{Encoder}(x), \quad F \in \mathbb{R}^{D}$$ +## 2. SpatialTranscriptFormer (Proposed Model) + +### 2.1 Design Philosophy + +The SpatialTranscriptFormer models the **interaction between biological pathways and histology** via four configurable information flows: + +1. **P↔P (Pathway self-interaction)**: Pathways refine each other's representations, capturing biological co-dependencies — e.g., EMT and inflammatory response pathways often co-activate in tumour microenvironments. + +2. **P→H (Pathway queries Histology)**: Pathway tokens query patch features with cross-attention, asking *"does this tissue region show morphological evidence of this biological process?"* — e.g., does a patch look consistent with angiogenesis or epithelial-mesenchymal transition? + +3. 
**H→P (Histology reads Pathways)**: Patch tokens attend to pathway tokens, receiving biological context — e.g., *"this patch is in a region where the inflammatory response pathway is highly active."* This contextualises the visual features with global biological state. + +4. **H↔H (Patch self-interaction)**: Patches attend to each other, enabling the model to capture spatial relationships between tissue regions directly. + +By default, the model operates in **Full Interaction** mode where all four information flows are active. Users can selectively disable any combination using the `--interactions` flag to explore architectural variants: + +```bash +# Default: Full Interaction (all quadrants enabled) +--interactions p2p p2h h2p h2h + +# Pathway Bottleneck: block H↔H to force all inter-patch +# communication through the pathway bottleneck +--interactions p2p p2h h2p +``` + +> [!TIP] +> The **Pathway Bottleneck** variant (disabling `h2h`) is particularly useful for **interpretability** — all spatial interactions are mediated by named biological pathways — and for **anti-collapse** — preventing patches from averaging into identical representations. + +Three additional design principles support these interactions: + +- **Frozen Foundation Model Backbone** — The visual backbone (CTransPath, Phikon, etc.) is a pre-trained pathology feature extractor. It is never fine-tuned. The model learns only the pathway-histology interactions, keeping training lightweight. + +- **Dense Spatial Supervision** — Unlike weak MIL (which uses slide-level labels), we supervise at the **spot level** using spatial transcriptomics. Every patch receives ground-truth expression, enabling the model to learn spatially-resolved pathway activation patterns. + +- **Biological Initialisation** — The gene reconstruction weights are initialised from MSigDB Hallmark gene sets, providing a biologically-grounded starting point that the model refines during training. 
+ +### 2.2 Spatial Learning + +The spatial relationships of gene expression are central to this model. It is not sufficient to predict correct expression magnitudes at each spot independently — the model must capture **where** on the tissue pathways are active and how that spatial pattern varies across the slide. Two mechanisms enforce this: + +1. **Positional Encoding** — Each patch token receives a 2D sinusoidal encoding of its (x, y) coordinate on the tissue. This means the pathway tokens, when they attend to patches, can distinguish *where* each patch is. A pathway token can learn that EMT is localised at the tumour-stroma boundary, not uniformly across the slide. + +2. **PCC Loss (Spatial Pattern Coherence)** — The Pearson Correlation component in the composite loss measures whether the *spatial pattern* of each gene's predicted expression matches the ground truth pattern, independently of scale. A model that predicts the same value everywhere scores PCC = 0, even if the mean is correct. This directly penalises spatial collapse. + +Together, these ensure the model learns *spatially-varying* pathway activation maps rather than slide-level averages. 
+### 2.3 Architecture + +```text +┌──────────────────────────────────────────────────────────────────────────────┐ +│ SpatialTranscriptFormer │ +│ │ +│ ┌─────────────┐ ┌──────────────┐ ┌──────────────────────────┐ │ +│ │ Frozen │ │ Image │ │ + Spatial PE │ │ +│ │ Backbone │──>│ Projection │──>│ (2D Sinusoidal) │ │ +│ │ (CTransPath)│ │ (Linear) │ │ │ │ +│ └─────────────┘ └──────────────┘ └────────┬─────────────────┘ │ +│ │ Patch Tokens (S, D) │ +│ ┌──────────────────────────┐ │ │ +│ │ Learnable Pathway │ │ │ +│ │ Tokens (P, D) │────────┐ │ │ +│ │ (MSigDB Hallmarks) │ │ │ │ +│ └──────────────────────────┘ ▼ ▼ │ +│ ┌─────────────────────────────┐ │ +│ │ Transformer Encoder │ │ +│ │ Sequence: [Pathways|Patches]│ │ +│ │ │ │ +│ │ Full Interaction (default): │ │ +│ │ • P↔P ✅ P→H ✅ │ │ +│ │ • H→P ✅ H↔H ✅ │ │ +│ │ │ │ +│ │ Configurable via │ │ +│ │ --interactions flag │ │ +│ └──────────┬──────────────────┘ │ +│ │ │ +│ ┌──────────▼──────────────────┐ │ +│ │ Cosine Similarity Scoring │ │ +│ │ with Learnable Temperature │ │ +│ │ │ │ +│ │ scores = cos(patch, pathway)│ │ +│ │ × τ │ │ +│ └──────────┬──────────────────┘ │ +│ │ Pathway Scores (S, P) │ +│ ┌───────────┴───────────┐ │ +│ │ │ │ +│ ▼ ▼ │ +│ ┌──────────────────────┐ ┌───────────────────────────┐ │ +│ │ Gene Reconstructor │ │ Auxiliary Pathway Loss │ │ +│ │ (Linear: P → G) │ │ PCC(scores, target_pw) │ │ +│ │ Init: MSigDB │ │ weighted by λ_aux │ │ +│ └──────────┬───────────┘ └───────────┬───────────────┘ │ +│ │ │ │ +│ ▼ ▼ │ +│ Gene Expression (S, G) ℒ_aux = λ(1 − PCC) │ +│ │ │ │ +│ └──────────┬───────────────┘ │ +│ ▼ │ +│ ℒ_total = ℒ_gene + ℒ_aux │ +└──────────────────────────────────────────────────────────────────────────────┘ +``` + +### 2.3 Key Components + +#### Frozen Backbone (Feature Extraction) + +Pre-computed features from a pathology foundation model. The backbone is never fine-tuned. + +| Backbone | Feature Dim | Source | +| :--- | :--- | :--- | +| **CTransPath** | 768 | [Wang et al. 
(2022)](https://arxiv.org/abs/2111.13324) | +| **GigaPath** | 1536 | [Microsoft Prov-GigaPath](https://hf.co/prov-gigapath/prov-gigapath) | +| **Hibou** | 768 / 1024 | [HistAI Hibou](https://hf.co/histai) | +| **Phikon** | 768 | [Owkin Phikon](https://hf.co/owkin/phikon) | +| **ResNet-50** | 2048 | Torchvision (ImageNet) | -#### 2. Pathway Tokenization +#### Pathway Tokenizer -We initialize $K$ learnable pathway tokens $T_{path} \in \mathbb{R}^K \times D_{token}$. These tokens act as "queries" to search for relevant morphological features in the image. +Learnable embeddings $T \in \mathbb{R}^{P \times D}$ representing biological pathways. These act as [CLS]-like bottleneck tokens, analogous to the "perceiver" cross-attention pattern. -#### 3. Interaction (Unified Early Fusion) +#### Sinusoidal Positional Encoding (2D) -The interaction logic is unified via the **EarlyFusionBlock**, which facilitates a generalized attention mechanism between Pathway tokens ($P$) and Histology tokens ($H$). The model operates on a concatenated sequence $X = [P; H]$. +Patch tokens receive spatial location information via 2D sinusoidal embeddings: +$$H_{PE} = H_{proj} + \text{PE}_{2D}(x, y)$$ +This encodes the absolute position of each patch on the tissue slide, enabling the model to learn spatially-varying pathway activation patterns. -##### Quadrant Masking +#### Configurable Interaction Masking -To control the topology of the interaction and ensure scalability, we apply a mask $M$ to the attention matrix: +The transformer uses a custom attention mask that controls information flow between token groups. By default, **all interactions are enabled** (Full Interaction). You can selectively disable any combination using the `--interactions` flag: -- **P2P (Pathway Self-Interaction)**: Pathways refine each other's latent signatures. -- **P2H (Cross-Attention)**: Pathways query the histology patches for morphological evidence. 
-- **H2P (Biological Feedback)**: Histology features can be influenced by pathway context. -- **H2H (Global Patch-to-Patch)**: Patches attend to other distal patches. +| Quadrant | Token Flow | Description | +| :--- | :--- | :--- | +| **p2p** | Pathway ↔ Pathway | Pathways refine each other | +| **p2h** | Pathway → Histology | Pathways gather spatial info from patches | +| **h2p** | Histology → Pathway | Patches receive contextualised pathway signals | +| **h2h** | Histology ↔ Histology | Patches attend to each other directly | -> [!IMPORTANT] -> **Default Configuration**: In the standard experiment, **H2H is masked**. This removes the $O(N_H^2)$ memory bottleneck, allowing the model to handle thousands of spots simultaneously while still capturing global pathway-level correlations and local morphology. +> [!NOTE] +> Disabling `h2h` creates the **Pathway Bottleneck** variant, where all inter-patch communication must flow through the pathway tokens. This requires **a minimum of 2 transformer layers**: Layer 1 lets pathways gather information from patches, and Layer 2 lets patches read the contextualised pathway tokens. -### 3.2 Spatial Inductive Biases +#### Cosine Similarity Scoring -#### Sinusoidal Positional Encoding (2D) +Pathway scores are computed via L2-normalized cosine similarity with a learnable temperature parameter $\tau$ (following CLIP): +$$s_{ij} = \cos(\hat{h}_i, \hat{p}_j) \times \tau$$ +where $\hat{h}_i$ and $\hat{p}_j$ are the L2-normalized processed patch and pathway tokens respectively. This produces scores in $[-\tau, +\tau]$ with meaningful relative differences, avoiding the saturation that occurs with raw dot-products. -We inject knowledge of absolute patch locations $(x, y)$ into the transformer latent space via 2D sinusoidal embeddings: -$$H_{PE} = H_{proj} + \text{PE}_{2D}(x, y)$$ -This ensures the attention mechanism is aware of the relative distances and orientations between histology features. 
- -#### Local Patch Mixing (Conv) +#### Gene Reconstruction (Biologically-Informed) -Because global `H2H` is typically masked for efficiency, we use a **LocalPatchMixer** to regain immediate spatial context. This module applies a depthwise convolution over the spatial grid *before* global interaction: -$$H_{mixed} = \text{GELU}(\text{DepthwiseConv2d}(H_{grid}))$$ -This allows the model to capture high-bandwidth local morphology (e.g., 3x3 window) while the Transformer focuses on global biology-driven interactions. +A linear layer $W_{recon} \in \mathbb{R}^{G \times P}$ maps pathway scores to gene expression: +$$\hat{y}_g = \sum_{k=1}^P s_k \cdot W_{gk} + b_g$$ -### 3.3 Whole-Slide Dense Prediction (`forward_dense`) +When `--pathway-init` is enabled, $W_{recon}$ is initialised from the MSigDB Hallmark gene sets as a binary membership matrix, giving the model a biologically-grounded starting point where each pathway is connected only to its known member genes. -When predicting gene expression for every spot across a slide, the model leverages its global context differently. Instead of taking a single coordinate as the 'target', it updates all histology tokens simultaneously: +### 2.4 Training Modes -1. **Global Context**: Pathways query the entire slide (or a large masked window). -2. **Dense Head**: Updated histology tokens $H'$ are projected back to gene space: - $$\hat{y}_{dense} = H' \cdot T_{path}^\top \cdot W_{recon}$$ - This ensures that the predicted expression logic is consistent with the pathway bottleneck used during patch-level training. +| Mode | Input | Output | Supervision | +| :--- | :--- | :--- | :--- | +| **Dense (whole-slide)** | All patches from a slide | Per-patch gene predictions $(B, S, G)$ | Masked MSE+PCC at each spot | +| **Global** | All patches from a slide | Slide-level prediction $(B, G)$ | Mean-pooled expression | --- -## 4. 
Scalability: Nyström Approximation - -To scale effectively to whole-slide contexts or extremely large neighborhoods ($N_H > 1000$), we utilize the **Nyström method** for approximating the self-attention matrix. This reduces the complexity from quadratic $O(N^2)$ to linear $O(N \cdot m)$. - -### Theoretical Intuition +## 3. Baseline Models -The core of Transformer attention is the kernel matrix $\mathbf{K} \in \mathbb{R}^{N \times N}$. Nyström approximates $\mathbf{K}$ using a small set of **landmark points** $m \ll N$: -$$\mathbf{K} \approx \mathbf{C} \mathbf{W}^{+} \mathbf{C}^\top$$ -Where: +### HE2RNA (ResNet-50) -- $\mathbf{C} \in \mathbb{R}^{N \times m}$ contains $m$ landmark columns. -- $\mathbf{W} \in \mathbb{R}^{m \times m}$ is the intersection matrix of those columns. -- $\mathbf{W}^{+}$ is the Moore-Penrose pseudo-inverse. +- **Reference**: [Schmauch et al. (2020), Nature Communications](https://www.nature.com/articles/s41467-020-17679-z) +- Direct regression from patch features to gene expression via a single linear layer. -### Decomposition in Attention +### Attention-MIL -In the context of attention $\text{Softmax}(\frac{QK^\top}{\sqrt{d}})V$, the Nyström approximation allows us to compute the interaction without ever forming the $N \times N$ matrix: -$$\tilde{A} = \text{Softmax}\left(\frac{Q \tilde{K}^\top}{\sqrt{d}}\right) \left[ \text{Softmax}\left(\frac{\tilde{K} \tilde{K}^\top}{\sqrt{d}}\right) \right]^{+} \text{Softmax}\left(\frac{\tilde{K} K^\top}{\sqrt{d}}\right)$$ -Where $\tilde{K}$ are the pooled landmark features. +- **Reference**: [Ilse et al. (2018), ICML](https://arxiv.org/abs/1802.04712) +- Learns gated attention weights to aggregate patches into a slide-level representation. -### Why it matters for Spatial Transcriptomics +### TransMIL -1. **Whole Slide Context**: Standard transformers fail on WSIs (e.g., 20,000 patches). Nyström enables slide-level correlation (implemented in the **TransMIL** model). -2. 
**Dense Neighborhoods**: Allows the **SpatialTranscriptFormer** to consider hundreds of surrounding patches for a single prediction with minimal GPU memory overhead. +- **Reference**: [Shao et al. (2021), NeurIPS](https://arxiv.org/abs/2106.00908) +- Nyström-based transformer for capturing long-range patch correlations. --- -## 5. Backbone Zoo +## 4. Loss Functions -The model supports multiple state-of-the-art pathology backbones. These are selected via the CLI using the `--backbone` flag. +| CLI Flag | Objective | Focus | +| :--- | :--- | :--- | +| `mse` | Masked MSE | Magnitude accuracy at each spot | +| `pcc` | 1 − PCC | Spatial pattern coherence per gene (scale-invariant) | +| `mse_pcc` | MSE + α(1 − PCC) | Balances absolute magnitude and spatial shape | +| `zinb` | ZINB NLL | Zero-Inflated Negative Binomial negative log-likelihood | -| Backbone | Variant | Feature Dim | Source / Reference | -| :--- | :--- | :--- | :--- | -| **ResNet** | `resnet50` | 2048 | Torchvision (ImageNet) | -| **CTransPath** | `ctranspath` | 768 | [Wang et al. (2022)](https://arxiv.org/abs/2111.13324) | -| **GigaPath** | `gigapath` | 1536 | [Microsoft Prov-GigaPath](https://hf.co/prov-gigapath/prov-gigapath) | -| **Hibou** | `hibou-b/l` | 768 / 1024 | [HistAI Hibou](https://hf.co/histai) | -| **Phikon** | `phikon` | 768 | [Owkin Phikon](https://hf.co/owkin/phikon) | -| **PLIP** | `plip` | 512 | [Huang et al. (2023)](https://hf.co/vinid/plip) | +### ZINB Loss -> [!NOTE] -> GigaPath and Hibou are **gated models**. You must accept the terms of use on their respective HuggingFace model pages before the code can download the weights. +The Zero-Inflated Negative Binomial (ZINB) loss is designed for raw, highly dispersed count data. It models the data using three parameters: ---- +- **$\pi$ (pi)**: Probability of zero-inflation (technical dropout). +- **$\mu$ (mu)**: Mean of the negative binomial distribution. +- **$\theta$ (theta)**: Inverse dispersion (clumping) parameter. -## 6. 
Biologically-Informed Bottleneck +The model outputs these parameters, and the loss computes the negative log-likelihood of the ground truth counts given this distribution. -The final gene expression $\hat{y}$ is a linear combination of pathway activations $s_k$: -$$\hat{y}_g = \sum_{k=1}^K s_k \cdot W_{gk} + b_g$$ +### Auxiliary Pathway Loss -### MSigDB Hallmarks Initialization +To prevent bottleneck collapse and provide a direct gradient signal to the pathway tokens, we use the `AuxiliaryPathwayLoss`. This loss compares the model's internal pathway scores against "ground truth" pathway activations computed from the gene expression targets via MSigDB membership. -When `--pathway-init` is enabled, $\mathbf{W}_{recon}$ is initialized from the **MSigDB Hallmark gene sets** (50 curated biological pathways). A binary membership matrix $\mathbf{M} \in \{0, 1\}^{P \times G}$ is constructed: -$$M_{kg} = \begin{cases} 1 & \text{if gene } g \in \text{pathway } k \\ 0 & \text{otherwise} \end{cases}$$ +The total objective becomes: +$$\mathcal{L} = \mathcal{L}_{gene} + \lambda_{aux} (1 - \text{PCC}(\text{pathway\_scores}, \text{target\_pathways}))$$ -The gene reconstructor weight is then initialized as $\mathbf{W}_{recon} \leftarrow \mathbf{M}^\top \in \mathbb{R}^{G \times P}$. This gives the model a biologically-grounded starting point where each pathway token is connected only to its known member genes. +The `--log-transform` flag applies `log1p` to targets, mitigating the heavy-tailed gene expression distribution where housekeeping genes dominate MSE. ---- - -## 7. Loss Functions and Gene Imbalance - -### The Gene Imbalance Problem +The full training objective with pathway sparsity regularisation: +$$\mathcal{L} = \mathcal{L}_{task} + \lambda \|W_{recon}\|_1$$ -Gene expression follows a **heavy-tailed distribution**: HOUSEKEEPING genes dominate, while signalling genes are rare. With standard MSE, high-expression genes dominate the loss. 
+--- -The `--log-transform` flag (`log1p`) is the primary mitigation. Additionally, the following loss objectives are supported: +## 5. CLI Flags (Model Configuration) -| CLI Flag | Objective | Focus | +| Flag | Default | Description | | :--- | :--- | :--- | -| `mse` | Mean Squared Error | Magnitude accuracy at each spot. | -| `pcc` | Pearson Correlation | Spatial pattern coherence per gene. | -| `mse_pcc` | Combined Loss | Balances absolute magnitude and spatial shape. | - -The complete training objective with pathway sparsity is: -$$\mathcal{L} = \mathcal{L}_{task} + \lambda \|\mathbf{W}_{recon}\|_1$$ +| `--model interaction` | — | Select SpatialTranscriptFormer | +| `--backbone` | `resnet50` | Foundation model backbone | +| `--token-dim` | 256 | Transformer embedding dimension | +| `--n-heads` | 4 | Number of attention heads | +| `--n-layers` | 2 | Transformer layers (minimum 2) | +| `--num-pathways` | 50 | Number of pathway bottleneck tokens | +| `--pathway-init` | off | Initialize gene_reconstructor from MSigDB | +| `--sparsity-lambda` | 0.0 | L1 regularisation on reconstruction weights | +| `--loss mse_pcc` | `mse` | Loss function (`mse`, `pcc`, `mse_pcc`, `zinb`) | +| `--pcc-weight` | 1.0 | Weight for PCC term in composite loss | +| `--pathway-loss-weight` | 0.0 | Weight for auxiliary pathway loss ($\lambda_{aux}$) | +| `--interactions` | `all` | Enabled interaction quadrants (`p2p p2h h2p h2h`) | +| `--log-transform` | off | Apply log1p to targets | +| `--return-attention` | off | Return attention maps from forward pass (for diagnostics) | +| `--n-neighbors` | 0 | Number of context neighbors (for hybrid/GNN models) | diff --git a/docs/TRAINING_GUIDE.md b/docs/TRAINING_GUIDE.md index 62aa736..3b0d9c0 100644 --- a/docs/TRAINING_GUIDE.md +++ b/docs/TRAINING_GUIDE.md @@ -126,18 +126,42 @@ python -m spatial_transcript_former.train \ > **Note**: `--pathway-init` overrides `--num-pathways` to 50 (the number of Hallmark gene sets). 
The GMT file is cached in `.cache/` after first download. -### Advanced: Multimodal Masking (Ablation) +### Robust Counting: ZINB + Auxiliary Loss -Test the model's robustness by masking specific quadrants (e.g., top-left and bottom-right). +For raw count data with high sparsity, using the ZINB distribution and auxiliary pathway supervision is recommended. ```bash python -m spatial_transcript_former.train \ --data-dir A:\hest_data \ --model interaction \ - --masked-quadrants top_left bottom_right \ - --precomputed + --backbone ctranspath \ + --pathway-init \ + --loss zinb \ + --pathway-loss-weight 0.5 \ + --lr 5e-5 \ + --batch-size 4 \ + --whole-slide \ + --precomputed \ + --epochs 200 ``` +### Choosing Interaction Modes + +By default, the model runs in **Full Interaction** mode (`p2p p2h h2p h2h`) where all token types attend to each other. You can selectively disable interactions using the `--interactions` flag for ablation or to enforce specific architectural constraints. + +For example, to use the **Pathway Bottleneck** (blocking patch-to-patch attention for interpretability): + +```bash +python -m spatial_transcript_former.train \ + --data-dir A:\hest_data \ + --model interaction \ + --interactions p2p p2h h2p \ + --precomputed \ + --whole-slide +``` + +Available interaction tokens: `p2p`, `p2h`, `h2p`, `h2h`. Default is all four (Full Interaction). + --- ## 4. HPC Batch Experiments @@ -145,7 +169,7 @@ python -m spatial_transcript_former.train \ The `hpc/array_train.slurm` script runs all three whole-slide experiments as a SLURM array job: | Index | Model | Supervision | Key Flags | -|:------|:------|:------------|:----------| +| :--- | :--- | :--- | :--- | | 0 | SpatialTranscriptFormer | Dense | `--whole-slide` | | 1 | AttentionMIL | Weak | `--whole-slide --weak-supervision` | | 2 | TransMIL | Weak | `--whole-slide --weak-supervision` | @@ -173,7 +197,7 @@ This produces a sorted comparison table and `comparison.csv`. 
Each training run automatically produces: | File | Description | -|:-----|:------------| +| :--- | :--- | | `training_log.csv` | Per-epoch metrics (train_loss, val_loss, attn_correlation) | | `results_summary.json` | Full config + final metrics + runtime | | `best_model_.pth` | Best checkpoint (by val loss) | @@ -196,10 +220,13 @@ python -m spatial_transcript_former.train --resume --output-dir runs/my_experime | `--weak-supervision` | Bag-level training for MIL models. | Use with `attention_mil` or `transmil`. | | `--pathway-init` | Initialize gene_reconstructor from MSigDB Hallmarks. | Use with `interaction` model. | | `--feature-dir` | Explicit path to precomputed features directory. | Overrides auto-detection. | +| `--loss` | Loss function: `mse`, `pcc`, `mse_pcc`, `zinb`. | `mse_pcc` or `zinb` recommended. | +| `--pathway-loss-weight` | Weight ($\lambda$) for auxiliary pathway supervision. | Set `0.5` or `1.0` with `interaction` model. | +| `--interactions` | Enabled attention quadrants: `p2p`, `p2h`, `h2p`, `h2h`. | Default: `all` (Full Interaction). | | `--log-transform` | Apply log1p to gene expression targets. | Recommended for raw count data. | | `--num-genes` | Number of HVGs to predict (default: 1000). | Match your `global_genes.json`. | | `--mask-radius` | Euclidean distance for spatial attention gating. | Usually between 200 and 800. | -| `--n-neighbors` | Number of context neighbors to load. | Set `> 0` for models using spatial context. | +| `--n-neighbors` | Number of context neighbors to load. | Set `> 0` for hybrid/GNN models. | | `--use-amp` | Mixed precision training. | Recommended on modern GPUs. | | `--grad-accum-steps` | Gradient accumulation steps. | Use when memory is limited. | | `--compile` | Use `torch.compile` for speed. | Recommended on Linux/A100. 
| diff --git a/pyproject.toml b/pyproject.toml index 62f7a4f..e4064ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,3 +44,13 @@ stf-build-vocab = "spatial_transcript_former.data.build_vocab:main" [tool.setuptools.packages.find] where = ["src"] include = ["spatial_transcript_former*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["src"] +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore:FigureCanvasAgg is non-interactive:UserWarning", + "ignore:os.fork:RuntimeWarning", + "ignore:torch.utils._pytree._register_pytree_node is deprecated:UserWarning", +] diff --git a/scripts/diagnose_collapse.py b/scripts/diagnose_collapse.py new file mode 100644 index 0000000..d764933 --- /dev/null +++ b/scripts/diagnose_collapse.py @@ -0,0 +1,414 @@ +""" +Diagnostic script for detecting model collapse in SpatialTranscriptFormer. + +Loads a checkpoint and a sample, then checks: +1. Pathway token diversity (are all pathway embeddings collapsing to the same vector?) +2. Gene reconstructor weights (are any pathways or genes dead, i.e. near-zero weight variance?) +3. Prediction diversity and ZINB parameter health (do outputs vary across patches/genes, and are pi/mu/theta in sensible ranges?) +4. Gradient flow (are gradients reaching all layers?) +5. Attention interactions (how much attention flows within and between pathway and spot tokens?) 
+ +Usage: + python scripts/diagnose_collapse.py --checkpoint runs/stf_interaction_zinb/best_model_interaction.pth --data-dir A:/hest_data +""" + +import argparse +import os +import sys +import json +import torch +import torch.nn.functional as F +import numpy as np + +from spatial_transcript_former.models.interaction import SpatialTranscriptFormer +from spatial_transcript_former.data.dataset import ( + HEST_FeatureDataset, + load_global_genes, +) +from spatial_transcript_former.data.pathways import get_pathway_init, MSIGDB_URLS + + +def load_model( + checkpoint_path, device, num_genes=1000, backbone="ctranspath", loss="zinb" +): + """Load model from checkpoint.""" + gene_list_path = "global_genes.json" + if not os.path.exists(gene_list_path): + gene_list_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "global_genes.json", + ) + with open(gene_list_path) as f: + gene_list = json.load(f) + + urls = [MSIGDB_URLS["hallmarks"]] + pathway_init, pathway_names = get_pathway_init(gene_list[:num_genes], gmt_urls=urls) + num_pathways = len(pathway_names) + + model = SpatialTranscriptFormer( + num_genes=num_genes, + backbone_name=backbone, + pretrained=False, + num_pathways=num_pathways, + pathway_init=pathway_init, + output_mode="zinb" if loss == "zinb" else "counts", + ) + + state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True) + # Handle torch.compile prefix + new_state_dict = {} + for k, v in state_dict.items(): + key = k.replace("_orig_mod.", "") + new_state_dict[key] = v + + model.load_state_dict(new_state_dict, strict=False) + model.to(device) + return model, pathway_names + + +def check_pathway_diversity(model): + """Check if pathway tokens have collapsed to the same vector.""" + pw = model.pathway_tokens.data.squeeze(0) # (P, D) + + # Pairwise cosine similarity + pw_norm = F.normalize(pw, dim=-1) + sim_matrix = pw_norm @ pw_norm.T # (P, P) + + # Off-diagonal similarities + mask = 
~torch.eye(sim_matrix.shape[0], dtype=torch.bool, device=sim_matrix.device) + off_diag = sim_matrix[mask] + + mean_sim = off_diag.mean().item() + max_sim = off_diag.max().item() + min_sim = off_diag.min().item() + + # Check variance of each pathway token + pw_std = pw.std(dim=-1) # (P,) + dead_pathways = (pw_std < 1e-6).sum().item() + + print("\n" + "=" * 60) + print("1. PATHWAY TOKEN DIVERSITY") + print("=" * 60) + print(f" Num pathways: {pw.shape[0]}, Token dim: {pw.shape[1]}") + print(f" Mean pairwise cosine sim: {mean_sim:.4f}") + print(f" Max pairwise cosine sim: {max_sim:.4f}") + print(f" Min pairwise cosine sim: {min_sim:.4f}") + print(f" Dead pathways (zero std): {dead_pathways}") + + if mean_sim > 0.9: + print(" ⚠️ WARNING: Pathway tokens are highly similar — possible collapse!") + elif mean_sim > 0.7: + print(" ⚠️ CAUTION: Pathway tokens are somewhat similar") + else: + print(" ✅ Pathway tokens appear diverse") + + return mean_sim + + +def check_gene_reconstructor(model): + """Check if gene_reconstructor weights are degenerate.""" + W = model.gene_reconstructor.weight.data # (G, P) + + col_std = W.std(dim=0) # Per-pathway variance across genes + row_std = W.std(dim=1) # Per-gene variance across pathways + + dead_pathways = (col_std < 1e-6).sum().item() + dead_genes = (row_std < 1e-6).sum().item() + + # Check sparsity (fraction of near-zero weights) + sparsity = (W.abs() < 1e-4).float().mean().item() + + print("\n" + "=" * 60) + print("2. 
GENE RECONSTRUCTOR WEIGHTS") + print("=" * 60) + print(f" Shape: {W.shape}") + print(f" Weight range: [{W.min().item():.4f}, {W.max().item():.4f}]") + print(f" Weight std: {W.std().item():.4f}") + print(f" Dead pathways (zero col std): {dead_pathways}/{W.shape[1]}") + print(f" Dead genes (zero row std): {dead_genes}/{W.shape[0]}") + print(f" Sparsity (|w| < 1e-4): {sparsity:.2%}") + + if dead_pathways > W.shape[1] * 0.5: + print(" ⚠️ WARNING: >50% of pathways are dead in gene_reconstructor!") + else: + print(" ✅ Gene reconstructor weights appear healthy") + + +def check_predictions(model, feats, coords, device): + """Run a forward pass and check prediction diversity.""" + model.eval() + with torch.no_grad(): + feats_t = feats.unsqueeze(0).to(device) + coords_t = coords.unsqueeze(0).to(device) + + output = model( + feats_t, return_dense=True, rel_coords=coords_t, return_pathways=True + ) + + if isinstance(output, tuple): + gene_output = output[0] + pathway_scores = output[1] + else: + gene_output = output + pathway_scores = None + + print("\n" + "=" * 60) + print("3. PREDICTION DIVERSITY") + print("=" * 60) + + if isinstance(gene_output, tuple): + # ZINB output: (pi, mu, theta) + pi, mu, theta = gene_output + pi, mu, theta = pi.squeeze(0), mu.squeeze(0), theta.squeeze(0) + + print(f" ZINB mode detected") + print(f" mu range: [{mu.min().item():.4f}, {mu.max().item():.4f}]") + print(f" mu mean: {mu.mean().item():.4f}") + print(f" mu std: {mu.std().item():.4f}") + print( + f" pi range: [{pi.min().item():.4f}, {pi.max().item():.4f}] (dropout probability)" + ) + print(f" pi mean: {pi.mean().item():.4f}") + print( + f" theta range: [{theta.min().item():.4f}, {theta.max().item():.4f}] (dispersion)" + ) + print(f" theta mean: {theta.mean().item():.4f}") + + # Spatial variance of mu (do predictions vary across patches?) 
+ mu_spatial_std = mu.std(dim=0).mean().item() # avg per-gene spatial std + print(f"\n Spatial std of mu (per-gene avg): {mu_spatial_std:.6f}") + + # Per-patch variance (do predictions vary across genes?) + mu_gene_std = mu.std(dim=1).mean().item() + print(f" Gene std of mu (per-patch avg): {mu_gene_std:.6f}") + + if mu_spatial_std < 1e-4: + print( + " ⚠️ WARNING: Almost zero spatial variation — model is predicting the same thing everywhere!" + ) + elif mu_spatial_std < 1e-2: + print(" ⚠️ CAUTION: Very low spatial variation") + else: + print(" ✅ Spatial variation appears present") + + if pi.mean().item() > 0.9: + print( + " ⚠️ WARNING: pi is very high — model thinks almost everything is a dropout!" + ) + elif pi.mean().item() < 0.01: + print(" ✅ pi is low — model is not over-relying on zero inflation") + else: + preds = gene_output.squeeze(0) # (N, G) + print( + f" Prediction range: [{preds.min().item():.4f}, {preds.max().item():.4f}]" + ) + spatial_std = preds.std(dim=0).mean().item() + print(f" Spatial std (per-gene avg): {spatial_std:.6f}") + + if spatial_std < 1e-4: + print(" ⚠️ WARNING: Model is predicting the same thing everywhere!") + + if pathway_scores is not None: + pw = pathway_scores.squeeze(0) # (N, P) + print(f"\n Pathway scores shape: {pw.shape}") + print( + f" Pathway scores range: [{pw.min().item():.4f}, {pw.max().item():.4f}]" + ) + print(f" Pathway scores std: {pw.std().item():.4f}") + + # Per-pathway spatial variance + pw_spatial_std = pw.std(dim=0) # (P,) + print( + f" Per-pathway spatial std: min={pw_spatial_std.min().item():.4f}, max={pw_spatial_std.max().item():.4f}, mean={pw_spatial_std.mean().item():.4f}" + ) + + dead_pw = (pw_spatial_std < 1e-6).sum().item() + print(f" Dead pathways (zero spatial var): {dead_pw}/{pw.shape[1]}") + + +def check_attention_interactions(model, feats, coords, device): + """Analyze attention maps to see how different token groups are interacting.""" + model.eval() + with torch.no_grad(): + feats_t = 
feats.unsqueeze(0).to(device) + coords_t = coords.unsqueeze(0).to(device) + + # SpatialTranscriptFormer.forward returns (gene_expr, pathway_scores, attentions) + # when return_pathways=True and return_attention=True + results = model( + feats_t, + rel_coords=coords_t, + return_pathways=True, + return_attention=True, + ) + + # Unpack results: (expr, p_scores, attentions) + if len(results) == 3: + attentions = results[2] + else: + # Fallback for unexpected return signature (shouldn't happen with my changes) + print( + " ⚠️ Could not extract attention maps (unexpected return signature)" + ) + return + + p = model.num_pathways + s = feats.shape[0] # number of spots + + print("\n" + "=" * 60) + print("5. ATTENTION INTERACTIONS") + print("=" * 60) + print(f" Interactions enabled: {model.interactions}") + + for i, attn in enumerate(attentions): + # attn is (B, T, T) where T = P + S + attn = attn.squeeze(0) # (T, T) + + # Interaction quadrants + p2p_m = attn[:p, :p].mean().item() + p2h_m = attn[:p, p:].mean().item() + h2p_m = attn[p:, :p].mean().item() + h2h_m = attn[p:, p:].mean().item() + + # Diagnostic for sparse attention: off-diagonal h2h + h2h_off_diag = attn[p:, p:].clone() + h2h_off_diag.fill_diagonal_(0) + h2h_off_diag_mean = h2h_off_diag.mean().item() + + print(f" LAYER {i}:") + print(f" p2p (Path -> Path): {p2p_m:.6f}") + print(f" p2h (Path -> Spot): {p2h_m:.6f}") + print(f" h2p (Spot -> Path): {h2p_m:.6f}") + print( + f" h2h (Spot -> Spot): {h2h_m:.6f} (off-diag: {h2h_off_diag_mean:.6f})" + ) + + # Warnings for "dead" interaction types + if h2p_m < 1e-4 and "h2p" in model.interactions: + print( + f" ⚠️ CAUTION: Spots are barely attending to pathways in Layer {i}" + ) + if p2h_m < 1e-4 and "p2h" in model.interactions: + print( + f" ⚠️ CAUTION: Pathways are barely attending to spots in Layer {i}" + ) + if h2h_off_diag_mean < 1e-7 and "h2h" in model.interactions: + print( + f" ⚠️ WARNING: Spot-to-spot attention is nearly zero despite being enabled!" 
+ ) + + +def check_gradient_flow(model, feats, coords, device): + """Check if gradients reach all parts of the model.""" + model.train() + model.zero_grad() + + feats_t = feats.unsqueeze(0).to(device) + coords_t = coords.unsqueeze(0).to(device) + + output = model(feats_t, return_dense=True, rel_coords=coords_t) + + if isinstance(output, tuple): + loss = output[1].sum() # mu + else: + loss = output.sum() + + loss.backward() + + print("\n" + "=" * 60) + print("4. GRADIENT FLOW") + print("=" * 60) + + layers = { + "pathway_tokens": model.pathway_tokens, + "image_proj.weight": model.image_proj.weight, + "gene_reconstructor.weight": model.gene_reconstructor.weight, + "fusion_engine.layers[0].self_attn.in_proj_weight": model.fusion_engine.layers[ + 0 + ].self_attn.in_proj_weight, + } + + if hasattr(model, "pi_reconstructor"): + layers["pi_reconstructor.weight"] = model.pi_reconstructor.weight + layers["theta_reconstructor.weight"] = model.theta_reconstructor.weight + + for name, param in layers.items(): + if param.grad is not None: + grad_norm = param.grad.norm().item() + grad_max = param.grad.abs().max().item() + print(f" {name:45s} grad_norm={grad_norm:.6f} grad_max={grad_max:.6f}") + if grad_norm < 1e-10: + print(f" ⚠️ DEAD gradient in {name}!") + else: + print(f" {name:45s} ⚠️ NO GRADIENT!") + + +def main(): + parser = argparse.ArgumentParser(description="Diagnose model collapse") + parser.add_argument("--checkpoint", type=str, required=True) + parser.add_argument("--data-dir", type=str, required=True) + parser.add_argument("--num-genes", type=int, default=1000) + parser.add_argument("--backbone", type=str, default="ctranspath") + parser.add_argument("--loss", type=str, default="zinb") + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + + # Load model + print("Loading model...") + model, pathway_names = load_model( + args.checkpoint, device, args.num_genes, args.backbone, args.loss + ) 
+ print(f"Loaded {len(pathway_names)} pathways") + + # Load a sample + print("Loading sample data...") + common_gene_names = load_global_genes(args.data_dir, args.num_genes) + + patches_dir = os.path.join(args.data_dir, "patches") + if not os.path.isdir(patches_dir): + patches_dir = args.data_dir + st_dir = os.path.join(args.data_dir, "st") + + feat_dir = os.path.join(args.data_dir, f"patches/he_features_{args.backbone}") + if not os.path.isdir(feat_dir): + feat_dir = os.path.join(args.data_dir, f"he_features_{args.backbone}") + + # Find first available sample + sample_files = [f for f in os.listdir(feat_dir) if f.endswith(".pt")] + if not sample_files: + print("ERROR: No feature files found!") + sys.exit(1) + + sample_id = sample_files[0].replace(".pt", "") + feature_path = os.path.join(feat_dir, sample_files[0]) + h5ad_path = os.path.join(st_dir, f"{sample_id}.h5ad") + + print(f"Using sample: {sample_id}") + + ds = HEST_FeatureDataset( + feature_path, + h5ad_path, + num_genes=args.num_genes, + selected_gene_names=common_gene_names, + whole_slide_mode=True, + ) + + feats, targets, coords = ds[0] + print(f"Features shape: {feats.shape}, Coords shape: {coords.shape}") + + # Run diagnostics + check_pathway_diversity(model) + check_gene_reconstructor(model) + check_predictions(model, feats, coords, device) + check_attention_interactions(model, feats, coords, device) + check_gradient_flow(model, feats, coords, device) + + print("\n" + "=" * 60) + print("DIAGNOSIS COMPLETE") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/scripts/monitor.py b/scripts/monitor.py index e9c959f..a3270ce 100644 --- a/scripts/monitor.py +++ b/scripts/monitor.py @@ -90,7 +90,7 @@ def serve_image(filename): ), html.Div( [ - # Graph for Losses + # Row 1: Losses + Correlation html.Div( [dcc.Graph(id="live-loss-graph", animate=False)], style={ @@ -99,9 +99,25 @@ def serve_image(filename): "verticalAlign": "top", }, ), - # Graph for Metrics html.Div( - 
[dcc.Graph(id="live-metric-graph", animate=False)], + [dcc.Graph(id="live-pcc-graph", animate=False)], + style={ + "width": "48%", + "display": "inline-block", + "verticalAlign": "top", + }, + ), + # Row 2: Variance + Learning Rate + html.Div( + [dcc.Graph(id="live-variance-graph", animate=False)], + style={ + "width": "48%", + "display": "inline-block", + "verticalAlign": "top", + }, + ), + html.Div( + [dcc.Graph(id="live-lr-graph", animate=False)], style={ "width": "48%", "display": "inline-block", @@ -157,80 +173,97 @@ def update_image(n): return html.Img(src=url, style={"maxWidth": "100%", "height": "auto"}) +def _make_traces(df, cols, smoothing_window): + """Create Plotly traces for the given columns with optional smoothing.""" + traces = [] + for col in cols: + if col not in df.columns: + continue + y_data = df[col].dropna() + epochs = df.loc[y_data.index, "epoch"] + if smoothing_window and smoothing_window > 1: + y_data = y_data.rolling(window=smoothing_window, min_periods=1).mean() + traces.append(go.Scatter(x=epochs, y=y_data, mode="lines", name=col)) + return traces + + @app.callback( [ Output("live-loss-graph", "figure"), - Output("live-metric-graph", "figure"), + Output("live-pcc-graph", "figure"), + Output("live-variance-graph", "figure"), + Output("live-lr-graph", "figure"), Output("last-updated", "children"), ], [Input("interval-component", "n_intervals"), Input("smoothing-slider", "value")], ) def update_graphs(n, smoothing_window): + empty = dash.no_update if not os.path.exists(log_path): - return ( - dash.no_update, - dash.no_update, - "Waiting for training_log.csv to be created...", - ) + return empty, empty, empty, empty, "Waiting for training_log.csv..." 
try: df = pd.read_csv(log_path) except Exception as e: - return dash.no_update, dash.no_update, f"Error reading log file: {str(e)}" + return empty, empty, empty, empty, f"Error reading log: {e}" if df.empty or "epoch" not in df.columns: - return ( - dash.no_update, - dash.no_update, - "Log file is empty or missing 'epoch' column.", - ) + return empty, empty, empty, empty, "Log empty or missing 'epoch'." - # Subplot 1: Losses - loss_cols = [c for c in df.columns if "loss" in c.lower()] - loss_traces = [] - for col in loss_cols: - y_data = df[col] - if smoothing_window and smoothing_window > 1: - y_data = y_data.rolling(window=smoothing_window, min_periods=1).mean() + margin = dict(l=40, r=40, t=40, b=40) - loss_traces.append(go.Scatter(x=df["epoch"], y=y_data, mode="lines", name=col)) - - loss_layout = go.Layout( - title="Training vs Validation Loss", - xaxis=dict(title="Epoch"), - yaxis=dict( - title="Loss", type="log" - ), # Log scale often helps with early MSE spikes - margin=dict(l=40, r=40, t=40, b=40), - ) + # Chart 1: Losses (log scale) + loss_cols = [c for c in df.columns if "loss" in c.lower()] + loss_fig = { + "data": _make_traces(df, loss_cols, smoothing_window), + "layout": go.Layout( + title="Loss", + xaxis=dict(title="Epoch"), + yaxis=dict(title="Loss", type="log"), + margin=margin, + ), + } - # Subplot 2: Interpretability Metrics - metric_cols = [c for c in df.columns if c not in loss_cols and c != "epoch"] - metric_traces = [] - for col in metric_cols: - y_data = df[col] - if smoothing_window and smoothing_window > 1: - y_data = y_data.rolling(window=smoothing_window, min_periods=1).mean() + # Chart 2: Correlation (PCC, MAE) + corr_cols = [c for c in ["val_pcc", "val_mae"] if c in df.columns] + pcc_fig = { + "data": _make_traces(df, corr_cols, smoothing_window), + "layout": go.Layout( + title="Correlation & Error", + xaxis=dict(title="Epoch"), + yaxis=dict(title="Score"), + margin=margin, + ), + } - metric_traces.append( - go.Scatter(x=df["epoch"], 
y=y_data, mode="lines", name=col) - ) + # Chart 3: Prediction Variance + var_cols = [c for c in ["pred_variance"] if c in df.columns] + var_fig = { + "data": _make_traces(df, var_cols, smoothing_window), + "layout": go.Layout( + title="Prediction Variance (collapse detector)", + xaxis=dict(title="Epoch"), + yaxis=dict(title="Variance", type="log"), + margin=margin, + ), + } - metric_layout = go.Layout( - title="Interpretability Metrics (MAE, PCC, Correlation)", - xaxis=dict(title="Epoch"), - yaxis=dict(title="Score / Error"), - margin=dict(l=40, r=40, t=40, b=40), - ) + # Chart 4: Learning Rate + lr_cols = [c for c in ["lr"] if c in df.columns] + lr_fig = { + "data": _make_traces(df, lr_cols, smoothing_window), + "layout": go.Layout( + title="Learning Rate Schedule", + xaxis=dict(title="Epoch"), + yaxis=dict(title="LR", type="log"), + margin=margin, + ), + } last_epoch = df["epoch"].iloc[-1] update_text = f"Last updated: Epoch {last_epoch} (Polled automatically)" - return ( - {"data": loss_traces, "layout": loss_layout}, - {"data": metric_traces, "layout": metric_layout}, - update_text, - ) + return loss_fig, pcc_fig, var_fig, lr_fig, update_text @app.callback( diff --git a/scripts/run_preset.py b/scripts/run_preset.py index 3e60dfa..8e51d2b 100644 --- a/scripts/run_preset.py +++ b/scripts/run_preset.py @@ -3,7 +3,26 @@ import sys import os +from spatial_transcript_former.config import get_config + +# Common flags for all STF interaction models +STF_COMMON = [ + "--model", + "interaction", + "--backbone", + "ctranspath", + "--precomputed", + "--whole-slide", + "--pathway-init", + "--use-amp", + "--log-transform", + "--loss", + "mse_pcc", + "--resume", +] + PRESETS = { + # --- Baselines --- "he2rna_baseline": [ "--model", "he2rna", @@ -36,58 +55,42 @@ "--batch-size", "1", ], - "stf_pathway_nystrom": [ - "--model", - "interaction", - "--backbone", - "ctranspath", - "--precomputed", - "--whole-slide", - "--use-nystrom", - "--pathway-init", - "--sparsity-lambda", - 
"0.05", - "--lr", - "1e-4", + # --- Interaction Models (Layer Scaling) --- + "stf_interaction_l2": STF_COMMON + + [ + "--n-layers", + "2", + "--token-dim", + "256", + "--n-heads", + "4", "--batch-size", - "8", - "--epochs", - "2000", - "--log-transform", - "--use-amp", - "--loss", - "mse_pcc", - "--resume", + "4", ], - "stf_pathway": [ - "--model", - "interaction", - "--backbone", - "ctranspath", - "--precomputed", - "--whole-slide", - "--pathway-init", - "--pathways", - "KEGG_COLORECTAL_CANCER", - "GRADE_COLON_CANCER_UP", - "GRADE_COLON_CANCER_DN", - "GRADE_COLON_AND_RECTAL_CANCER_UP", - "GRADE_COLON_AND_RECTAL_CANCER_DN", - "--sparsity-lambda", - "0.01", - "--lr", - "1e-5", + "stf_interaction_l4": STF_COMMON + + [ + "--n-layers", + "4", + "--token-dim", + "384", + "--n-heads", + "8", "--batch-size", + "4", + ], + "stf_interaction_l6": STF_COMMON + + [ + "--n-layers", + "6", + "--token-dim", + "512", + "--n-heads", "8", - "--epochs", - "2000", - "--log-transform", - "--use-amp", - "--loss", - "mse_pcc", - "--resume", + "--batch-size", + "2", # Reduced batch size for large model memory ], - "stf_pathway_hybrid": [ + # --- Specific Configurations --- + "stf_interaction_zinb": [ "--model", "interaction", "--backbone", @@ -95,54 +98,49 @@ "--precomputed", "--whole-slide", "--pathway-init", - "--pathways", - "KEGG_COLORECTAL_CANCER", - "GRADE_COLON_CANCER_UP", - "GRADE_COLON_CANCER_DN", - "GRADE_COLON_AND_RECTAL_CANCER_UP", - "GRADE_COLON_AND_RECTAL_CANCER_DN", "--sparsity-lambda", - "0.01", + "0", "--lr", "1e-4", "--batch-size", - "8", + "4", "--epochs", "2000", - "--log-transform", "--use-amp", "--loss", - "mse_pcc", + "zinb", + "--log-transform", + "--pathway-loss-weight", + "0.5", + "--interactions", + "p2p", + "p2h", + "h2p", + "h2h", + "--plot-pathways", "--resume", ], - "stf_pathway_gnn": [ - "--model", - "interaction", - "--backbone", - "ctranspath", - "--precomputed", - "--whole-slide", - "--pathway-init", + # Legacy / Specialized + "stf_pathway_nystrom": 
STF_COMMON + + [ "--sparsity-lambda", - "0.01", + "0.05", "--lr", "1e-4", "--batch-size", "8", "--epochs", "2000", - "--log-transform", - "--use-amp", - "--loss", - "mse", - "--resume", - "--early-mixer", - "none", - "--late-refiner", - "gnn", ], } +# Add alias for the one currently running to ensure backward compatibility for the monitor +PRESETS["stf_interaction_mse_pcc"] = PRESETS["stf_interaction_l2"] + [ + "--pathway-loss-weight", + "0.5", + "--plot-pathways", +] + def main(): parser = argparse.ArgumentParser( @@ -161,7 +159,7 @@ def main(): default=get_config("data_dirs", ["A:\\hest_data"])[0], help="Data directory", ) - parser.add_argument("--epochs", type=int, default=10, help="Number of epochs") + parser.add_argument("--epochs", type=int, default=2000, help="Number of epochs") parser.add_argument("--max-samples", type=int, default=None, help="Limit samples") parser.add_argument( "--output-dir", type=str, default=None, help="Override output dir" diff --git a/setup.ps1 b/setup.ps1 index 896303f..9c1bd04 100644 --- a/setup.ps1 +++ b/setup.ps1 @@ -7,8 +7,8 @@ $EnvName = "SpatialTranscriptFormer" # Check if conda environment exists $CondaEnv = conda env list | Select-String $EnvName if ($null -eq $CondaEnv) { - Write-Host "Creating conda environment '$EnvName' with Python 3.10..." -ForegroundColor Yellow - conda create -n $EnvName python=3.10 -y + Write-Host "Creating conda environment '$EnvName' with Python 3.9..." -ForegroundColor Yellow + conda create -n $EnvName python=3.9 -y } else { Write-Host "Conda environment '$EnvName' already exists." -ForegroundColor Green @@ -22,5 +22,6 @@ Write-Host "Setup Complete!" 
-ForegroundColor Green Write-Host "You can now use the following commands:" Write-Host " stf-download --help" Write-Host " stf-split --help" +Write-Host " stf-build-vocab --help" Write-Host "" Write-Host "To run tests, use: .\test.ps1" diff --git a/setup.sh b/setup.sh index 4a18e26..a16c23e 100644 --- a/setup.sh +++ b/setup.sh @@ -7,8 +7,8 @@ ENV_NAME="SpatialTranscriptFormer" # Check if conda environment exists if ! conda env list | grep -q "$ENV_NAME"; then - echo "Creating conda environment '$ENV_NAME' with Python 3.10..." - conda create -n $ENV_NAME python=3.10 -y + echo "Creating conda environment '$ENV_NAME' with Python 3.9..." + conda create -n $ENV_NAME python=3.9 -y else echo "Conda environment '$ENV_NAME' already exists." fi @@ -20,7 +20,7 @@ echo "" echo "Setup Complete!" echo "You can now use the following commands (after activating the environment):" echo " stf-download --help" -echo " stf-download-bowel" echo " stf-split --help" +echo " stf-build-vocab --help" echo "" echo "To run tests, use: ./test.sh" diff --git a/src/spatial_transcript_former/models/interaction.py b/src/spatial_transcript_former/models/interaction.py index 34d50c4..695cc77 100644 --- a/src/spatial_transcript_former/models/interaction.py +++ b/src/spatial_transcript_former/models/interaction.py @@ -2,14 +2,9 @@ Pathway-histology interaction layers for SpatialTranscriptFormer. 
This module defines the building blocks that fuse learnable pathway tokens -with histology patch features via self- and cross-attention: - -* ``PathwayTokenizer`` — learnable pathway embeddings -* ``LocalPatchMixer`` — scatter-gather depthwise conv for neighbourhood mixing -* ``SinusoidalPositionalEncoder`` — 2-D sinusoidal positional embeddings -* ``EarlyFusionBlock`` — concat-attention-slice wrapper (Jaume-mode) -* ``MultimodalFusion`` — standard Transformer early fusion -* ``NystromEncoder`` — Nyström-based early fusion +with histology patch features via self-attention: + +* ``LearnedSpatialEncoder`` — 2-D learned positional embeddings * ``SpatialTranscriptFormer`` — the full model """ @@ -19,449 +14,53 @@ from .backbones import get_backbone -class PathwayTokenizer(nn.Module): - """Projects pathway indices or data into latent embeddings. - - This module maintains a set of learnable embeddings of size dim for each - of the num_pathways defined. - - Attributes: - num_pathways (int): Number of pathway tokens to generate. - dim (int): Dimension of each pathway token. - """ - - def __init__(self, num_pathways, dim): - """Initializes the PathwayTokenizer. - - Args: - num_pathways (int): Total number of pathways. - dim (int): Token embedding dimension. - """ - super().__init__() - self.num_pathways = num_pathways - self.dim = dim - self.pathway_embeddings = nn.Parameter(torch.randn(1, num_pathways, dim)) - - def forward(self, batch_size): - """Generates pathway tokens for the batch. - - Args: - batch_size (int): Number of samples in the batch. - - Returns: - torch.Tensor: Pathway tokens of shape (batch_size, num_pathways, dim). - """ - return self.pathway_embeddings.expand(batch_size, -1, -1) - - -class LocalPatchMixer(nn.Module): - """Mixes each patch's features with its immediate spatial neighbours via a depthwise conv. - - The forward pass performs three steps: +class LearnedSpatialEncoder(nn.Module): + """Encodes 2D spatial coordinates via a small learned MLP. - 1. 
**Scatter** — place each patch feature vector into a 2-D dense grid at its - grid-coordinate position. Coordinates are zero-based and must be - integer-like (grid indices, *not* raw pixel coordinates). - 2. **Depthwise Conv + GELU** — apply a ``kernel_size x kernel_size`` grouped - convolution over the spatial grid. Each channel is processed independently - so the parameter count scales with *dim*, not *dim²*. - 3. **Gather + Residual** — read the convolved values back at the same grid - positions and add them to the original features (residual connection). - - A safety guard skips mixing if the bounding-box of grid coordinates exceeds - 256×256 cells, which would indicate malformed / non-grid coordinates and - could allocate prohibitive memory. + Unlike sinusoidal PE, this produces smooth, non-periodic embeddings + that vary gradually across the tissue. Coordinates are normalised to + [-1, 1] per-batch before encoding for scale invariance. """ - def __init__(self, dim, kernel_size=3): + def __init__(self, dim): super().__init__() - self.dim = dim - self.conv = nn.Conv2d( - dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim + self.mlp = nn.Sequential( + nn.Linear(2, dim), + nn.GELU(), + nn.Linear(dim, dim), ) - self.act = nn.GELU() - - def forward(self, x, coords): - """ - Args: - x: (B, N, D) Patch features - coords: (B, N, 2) Coordinates (absolute-ish indices) - - Returns: - (B, N, D) Enriched features - """ - B, N, D = x.shape - device = x.device - - # 1. Normalize coords to grid indices - # We assume coords are already roughly grid-like or pixel coords. - # But to be safe, we subtract min and assume unit step? - # Actually, let's assume coords ARE grid indices for now, or scaled pixels. - # If they are large pixels (e.g. 10000), we need to know the patch size (224). - # We'll trust the caller to pass grid-indices. - - # Find bounds per batch? Or global max? 
- # To batch this efficiently, we find global bounds in the batch - min_c = coords.min(dim=1, keepdim=True)[0] # (B, 1, 2) - grid_coords = coords - min_c # Zero-based - - max_c = grid_coords.max(dim=1)[0].max(dim=0)[0] # (2,) - H, W = int(max_c[1].item()) + 1, int(max_c[0].item()) + 1 - - # Cap memory usage if outliers exist - if H * W > 256 * 256: - # Fallback: Don't crash on massive outlier. Just skip mix? - # Or clamp? Let's skip mixing if grid is absurd. - return x - - # 2. Scatter - grid = torch.zeros(B, D, H, W, device=device, dtype=x.dtype) - - b_idx = torch.arange(B, device=device).view(B, 1, 1) - d_idx = torch.arange(D, device=device).view(1, D, 1) - - gx = grid_coords[..., 0].long() # (B, N) - gy = grid_coords[..., 1].long() # (B, N) - - gy_idx = gy.unsqueeze(1) # (B, 1, N) - gx_idx = gx.unsqueeze(1) # (B, 1, N) - - x_T = x.transpose(1, 2) # (B, D, N) - grid[b_idx, d_idx, gy_idx, gx_idx] = x_T - - # 3. Conv - out_grid = self.act(self.conv(grid)) - - # 4. Gather & Residual - out = out_grid[b_idx, d_idx, gy_idx, gx_idx].transpose( - 1, 2 - ) # (B, D, N) -> (B, N, D) - - return x + out - - -class GraphPatchMixer(nn.Module): - """Mixes each patch's features with its k-nearest physical neighbors using Graph Attention. - - This acts as a spatial refiner. It constructs a k-NN graph on the fly from the - physical coordinates and performs message passing. 
- """ - - def __init__(self, dim, k=8, heads=4): - super().__init__() - self.dim = dim - self.k = k - self.heads = heads - self.head_dim = dim // heads - - assert self.head_dim * heads == dim, "dim must be divisible by heads" - - # GAT-style linear projections - self.to_qkv = nn.Linear(dim, dim * 3, bias=False) - self.proj = nn.Linear(dim, dim) - self.act = nn.GELU() - - def forward(self, x, coords, mask=None): - """ - Args: - x: (B, N, D) Patch features - coords: (B, N, 2) Coordinates (absolute physical positions) - mask: (B, N) Boolean padding mask (True = padding) - - Returns: - (B, N, D) Refined features - """ - B, N, D = x.shape - device = x.device - - # 1. Build k-NN graph - # Compute pairwise distances (B, N, N) - dist = torch.cdist(coords, coords) - - k_actual = min(self.k + 1, N) # +1 for self-loop, bounded by N - dist_for_topk = dist.clone() - if mask is not None: - dist_for_topk.masked_fill_(mask.unsqueeze(1).expand(B, N, N), float("inf")) - - _, nn_idx = torch.topk(-dist_for_topk, k=k_actual, dim=-1) # (B, N, K) - - # 2. Extract neighbor features - batch_indices = torch.arange(B, device=device).view(-1, 1, 1) # (B, 1, 1) - x_neighbors = x[batch_indices, nn_idx, :] # (B, N, K, D) - - # 3. 
Message Passing (GAT-style) - # Query comes from the center node, Key/Value come from neighbors - qkv_center = self.to_qkv(x) # (B, N, 3D) - q_center, _, _ = qkv_center.chunk(3, dim=-1) # (B, N, D) - - qkv_neighbors = self.to_qkv(x_neighbors) - _, k_neighbors, v_neighbors = qkv_neighbors.chunk(3, dim=-1) # (B, N, K, D) - - # Reshape for multi-head attention - q = q_center.view(B, N, self.heads, self.head_dim) # (B, N, H, d) - k = k_neighbors.view( - B, N, k_actual, self.heads, self.head_dim - ) # (B, N, K, H, d) - v = v_neighbors.view( - B, N, k_actual, self.heads, self.head_dim - ) # (B, N, K, H, d) - - # Attention scores: Q * K^T - q = q.unsqueeze(3) # (B, N, H, 1, d) - k = k.permute(0, 1, 3, 2, 4) # (B, N, H, K, d) - - # dot product over 'd' - attn = (q * k).sum(dim=-1) / (self.head_dim**0.5) # (B, N, H, K) - - if mask is not None: - batch_indices = torch.arange(B, device=device).view(-1, 1, 1) # (B, 1, 1) - neighbor_is_padded = mask[batch_indices, nn_idx] # (B, N, K) - neighbor_is_padded = neighbor_is_padded.unsqueeze(2).expand( - -1, -1, self.heads, -1 - ) - attn = attn.masked_fill(neighbor_is_padded, float("-inf")) - - attn = F.softmax(attn, dim=-1) # (B, N, H, K) - if mask is not None: - attn = torch.nan_to_num(attn, nan=0.0) - - # Weighted sum of values - v = v.permute(0, 1, 3, 2, 4) # (B, N, H, K, d) - out = (attn.unsqueeze(-1) * v).sum(dim=-2) # (B, N, H, d) - - # Reshape back to D - out = out.reshape(B, N, D) - out = self.proj(self.act(out)) - - # 4. Residual Connection - return x + out - -class SinusoidalPositionalEncoder(nn.Module): - """Encodes 2D spatial coordinates into sinusoidal embeddings. - - - Based on the "Attention is All You Need" 1D PE, extended to 2D. - Each spatial dimension (x, y) is encoded separately with dim/2 channels. 
- """ - - def __init__(self, dim, temperature=10000): - super().__init__() - self.dim = dim - self.temperature = temperature + def _normalize_coords(self, coords): + """Normalize coordinates to [-1, 1] range per batch.""" + # Centre at zero + coords = coords - coords.mean(dim=1, keepdim=True) + # Scale to [-1, 1] + scale = coords.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1.0) + return coords / scale def forward(self, rel_coords): """ Args: - rel_coords (torch.Tensor): (B, S, 2) relative coordinates. + rel_coords (torch.Tensor): (B, S, 2) spatial coordinates. Returns: torch.Tensor: (B, S, dim) positional embeddings. """ - x = rel_coords[..., 0] - y = rel_coords[..., 1] - - # Split dim into two halves for x and y - dim_x = self.dim // 2 - dim_y = self.dim - dim_x - - # Create geometric progression of frequencies - # div_term = 1 / (temperature ** (2i / d_model)) - div_term_x = torch.exp( - torch.arange(0, dim_x, 2, dtype=torch.float32, device=rel_coords.device) - * -(torch.log(torch.tensor(self.temperature)) / dim_x) - ) - div_term_y = torch.exp( - torch.arange(0, dim_y, 2, dtype=torch.float32, device=rel_coords.device) - * -(torch.log(torch.tensor(self.temperature)) / dim_y) - ) - - pe_x = torch.zeros(*x.shape, dim_x, device=rel_coords.device) - pe_y = torch.zeros(*y.shape, dim_y, device=rel_coords.device) - - # Sin/Cos pairs for X - pe_x[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term_x) - pe_x[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term_x) - - # Sin/Cos pairs for Y - pe_y[..., 0::2] = torch.sin(y.unsqueeze(-1) * div_term_y) - pe_y[..., 1::2] = torch.cos(y.unsqueeze(-1) * div_term_y) - - # Concatenate X and Y embeddings -> (B, S, dim) - return torch.cat([pe_x, pe_y], dim=-1) - - -class EarlyFusionBlock(nn.Module): - """Generic wrapper for Early Fusion interaction. - - Handles the common logic for "concatenation -> attention -> slicing" used in - multimodal interaction (Jaume et al. mode). 
- """ - - def __init__(self, encoder, masked_quadrants=None): - """Initializes EarlyFusionBlock. - - Args: - encoder (nn.Module): The attention mechanism (Standard or Nystrom). - masked_quadrants (list, optional): Quadrants to mask. - """ - super().__init__() - self.encoder = encoder - self.masked_quadrants = masked_quadrants or [] - - def _generate_quadrant_mask(self, num_p, num_h, device): - """Generates Eq 1 quadrant mask to restrict attention between modalities. - - [ P2P P2H ] - [ H2P H2H ] - """ - n_total = num_p + num_h - mask = torch.zeros((n_total, n_total), device=device, dtype=torch.bool) - - p_slice = slice(0, num_p) - h_slice = slice(num_p, n_total) - - if "H2H" in self.masked_quadrants: - mask[h_slice, h_slice] = True - if "H2P" in self.masked_quadrants: - mask[h_slice, p_slice] = True - if "P2H" in self.masked_quadrants: - mask[p_slice, h_slice] = True - - return mask - - def forward(self, p_tokens, h_tokens, return_all_tokens=False): - """Performs multimodal fusion. - - Args: - p_tokens: Pathway tokens (B, Np, D) - h_tokens: Histology tokens (B, Nh, D) - return_all_tokens: If True, returns full concatenated sequence. - - Returns: - Contextualized tokens. - """ - np = p_tokens.shape[1] - nh = h_tokens.shape[1] - - x = torch.cat([p_tokens, h_tokens], dim=1) - - mask = None - if self.masked_quadrants: - mask = self._generate_quadrant_mask(np, nh, p_tokens.device) - - # Standard PyTorch MHA and NystromAttention use different mask signatures. - # We handle the dispatch here to ensure correct mask application. 
- out = self.encoder(x, mask=mask) - - if return_all_tokens: - return out - - return out[:, :np, :] - - -class MultimodalFusion(EarlyFusionBlock): - """Standard Transformer-based Early Fusion.""" - - def __init__(self, dim, n_heads, n_layers, dropout=0.1, masked_quadrants=None): - encoder_layer = nn.TransformerEncoderLayer( - d_model=dim, - nhead=n_heads, - dim_feedforward=dim * 4, - dropout=dropout, - batch_first=True, - ) - transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) - super().__init__(encoder=transformer, masked_quadrants=masked_quadrants) + return self.mlp(self._normalize_coords(rel_coords)) -class NystromStack(nn.Module): - """Sequential stack of Nyström attention layers, compatible with ``EarlyFusionBlock``. - - Each element of *layers* is a ``nn.ModuleList`` of - ``(norm1, NystromAttention, norm2, FeedForward)`` sub-modules following a - pre-norm residual layout. The attention mask convention differs from - standard PyTorch MHA: ``True`` means *keep* (attend), so the mask is - inverted (``~mask``) before being passed to ``NystromAttention``. 
- """ - - def __init__(self, layers): - super().__init__() - self.layers = nn.ModuleList(layers) - - def forward(self, x, mask=None): - attn_mask = ~mask if mask is not None else None - - for norm1, attn, norm2, ff in self.layers: - x = x + attn(norm1(x), mask=attn_mask) - x = x + ff(norm2(x)) - return x - - -class NystromEncoder(EarlyFusionBlock): - """Nystrom-based Early Fusion.""" - - def __init__( - self, - dim, - n_heads, - n_layers, - dropout=0.1, - num_landmarks=256, - masked_quadrants=None, - ): - from nystrom_attention import NystromAttention - - layers = [] - for _ in range(n_layers): - layers.append( - nn.ModuleList( - [ - nn.LayerNorm(dim), - NystromAttention( - dim=dim, - heads=n_heads, - dim_head=dim // n_heads, - num_landmarks=num_landmarks, - dropout=dropout, - ), - nn.LayerNorm(dim), - nn.Sequential( - nn.Linear(dim, dim * 4), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(dim * 4, dim), - nn.Dropout(dropout), - ), - ] - ) - ) - - stack = NystromStack(layers) - super().__init__(encoder=stack, masked_quadrants=masked_quadrants) - - -# Removed NystromDecoderLayer and NystromDecoder as we are focusing solely on the Jaume pipeline. +VALID_INTERACTIONS = {"p2p", "p2h", "h2p", "h2h"} class SpatialTranscriptFormer(nn.Module): """Transformer for predicting gene expression from histology and spatial context. Integrates histology feature extraction with pathway-based bottleneck - attention to predict gene transcript counts. Supports standard decoder-style - interaction or early fusion multimodal interaction. - - Architectural Inspiration: - Jaume et al. (2024). "Modeling Dense Multimodal Interactions Between - Biological Pathways and Histology for Survival Prediction" (SurvPath). - - Benchmark/Framework: - Jaume et al. (2024). "HEST-1k: A Dataset for Spatial Transcriptomics - and Histology Image Analysis." NeurIPS (Spotlight). + attention to predict gene transcript counts. 
Follows a standard Vision + Transformer architecture where pathway tokens act as [CLS]-like bottlenecks. Attributes: num_pathways (int): Number of pathway bottlenecks. - use_nystrom (bool): Whether to use efficient Nystrom attention. """ def __init__( @@ -470,19 +69,14 @@ def __init__( num_pathways=50, backbone_name="resnet50", pretrained=True, - token_dim=512, - n_heads=8, + token_dim=256, + n_heads=4, n_layers=2, dropout=0.1, - use_nystrom=False, - mask_radius=None, - masked_quadrants=None, - num_landmarks=256, pathway_init=None, use_spatial_pe=True, - early_mixer="conv", - late_refiner=None, - k_neighbors=8, + output_mode="counts", + interactions=None, ): """Initializes the SpatialTranscriptFormer. @@ -495,27 +89,42 @@ def __init__( n_heads (int): Number of attention heads. n_layers (int): Number of transformer/interaction layers. dropout (float): Dropout probability. - use_nystrom (bool): Enable linear-complexity attention. - mask_radius (float, optional): Distance-based masking threshold. - masked_quadrants (list, optional): Mask configuration for fusion. - num_landmarks (int): Landmarks for Nystrom attention. pathway_init (Tensor, optional): Biological pathway membership matrix of shape (P, G) to initialize gene_reconstructor. use_spatial_pe (bool): Incorporate relative gradients into attention. - early_mixer (str, optional): 'conv' or None. - late_refiner (str, optional): 'gnn' or None. - k_neighbors (int): k-NN for GNN refiner. + output_mode (str): 'counts' (standard MSE/PCC) or 'zinb' (Zero-Inflated Negative Binomial outputs). + interactions (list[str], optional): Which attention interactions to + enable. Valid keys are ``p2p``, ``p2h``, ``h2p``, ``h2h``. + Defaults to all four (full self-attention). """ super().__init__() + if interactions is None: + interactions = list(VALID_INTERACTIONS) + unknown = set(interactions) - VALID_INTERACTIONS + if unknown: + raise ValueError( + f"Unknown interaction keys: {unknown}. 
" + f"Valid keys are: {VALID_INTERACTIONS}" + ) + self.interactions = set(interactions) + + # Enforce minimum 2 layers when h2h is blocked. + # Layer 1 lets pathways gather patch info, Layer 2 lets patches + # read the now-informed pathway tokens. + if "h2h" not in self.interactions and n_layers < 2: + raise ValueError( + f"n_layers must be >= 2 when h2h is not in interactions. " + f"Got n_layers={n_layers}. Layer 1 lets pathways gather spatial info, " + f"Layer 2 lets patches read contextualized pathways." + ) + # Override num_pathways if biological init is provided if pathway_init is not None: num_pathways = pathway_init.shape[0] print(f"Pathway init: overriding num_pathways to {num_pathways}") self.num_pathways = num_pathways - self.use_nystrom = use_nystrom - self.mask_radius = mask_radius self.use_spatial_pe = use_spatial_pe # 1. Image Encoder Backbone @@ -523,50 +132,35 @@ def __init__( backbone_name, pretrained=pretrained ) - # 2. Image Projector & Spatial Modules + # 2. Image Projector self.image_proj = nn.Linear(self.image_feature_dim, token_dim) - self.early_mixer = None - if early_mixer == "conv": - self.early_mixer = LocalPatchMixer(token_dim) - - self.late_refiner = None - if late_refiner == "gnn": - self.late_refiner = GraphPatchMixer( - dim=token_dim, k=k_neighbors, heads=n_heads - ) - - self.pathway_tokenizer = PathwayTokenizer(num_pathways, token_dim) + # 3. Learnable pathway tokens (one per pathway, shared across batch) + self.pathway_tokens = nn.Parameter(torch.randn(1, num_pathways, token_dim)) - # 3b. Spatial Positional Encoder + # 4. Spatial Positional Encoder (optional) self.spatial_encoder = None if use_spatial_pe: - self.spatial_encoder = SinusoidalPositionalEncoder(token_dim) + self.spatial_encoder = LearnedSpatialEncoder(token_dim) - # 4. Interaction Engine (Unified Early Fusion Path) - if use_nystrom: - if masked_quadrants: - print( - "Warning: Nystrom attention does not support 2D quadrant masking. Mask will be ignored." 
- ) - self.fusion_engine = NystromEncoder( - token_dim, - n_heads, - n_layers, - dropout=dropout, - num_landmarks=num_landmarks, - masked_quadrants=None, - ) - else: - self.fusion_engine = MultimodalFusion( - token_dim, - n_heads, - n_layers, - dropout=dropout, - masked_quadrants=masked_quadrants, - ) + # 5. Interaction Engine (Standard Transformer) + encoder_layer = nn.TransformerEncoderLayer( + d_model=token_dim, + nhead=n_heads, + dim_feedforward=token_dim * 4, + dropout=dropout, + batch_first=True, + norm_first=True, + ) + + self.fusion_engine = nn.TransformerEncoder( + encoder_layer, num_layers=n_layers, enable_nested_tensor=False + ) + + # Learnable temperature for cosine similarity scoring + # Initialized to log(1/0.07) ≈ 2.66 following CLIP convention + self.log_temperature = nn.Parameter(torch.tensor(2.6593)) - self.pathway_activator = nn.Linear(token_dim, 1) self.gene_reconstructor = nn.Linear(num_pathways, num_genes) if pathway_init is not None: @@ -575,215 +169,212 @@ def __init__( # pathway_init is (num_pathways, num_genes) self.gene_reconstructor.weight.copy_(pathway_init.T) self.gene_reconstructor.bias.zero_() + # Expose the MSigDB matrix for AuxiliaryPathwayLoss + self._pathway_init_matrix = pathway_init.clone() print("Initialized gene_reconstructor with MSigDB Hallmarks") + else: + self._pathway_init_matrix = None - # Dense Head (Whole Slide Mode) - self.gene_head = nn.Linear(token_dim, num_genes) + self.output_mode = output_mode + if self.output_mode == "zinb": + # pi: probability of dropout (zero-inflation) + self.pi_reconstructor = nn.Linear(num_pathways, num_genes) + # theta: inverse dispersion + self.theta_reconstructor = nn.Linear(num_pathways, num_genes) - def get_sparsity_loss(self): - """Computes L1 norm of reconstruction weights for sparsity regularization. 
+ # Initialize ZINB specialized heads carefully to avoid immediate collapse + with torch.no_grad(): + # Initialize pi aggressively negative so sigmoid(pi) is near 0.05 + nn.init.normal_(self.pi_reconstructor.weight, std=0.01) + nn.init.constant_(self.pi_reconstructor.bias, -3.0) - Returns: - torch.Tensor: L1 loss value. - """ - return torch.norm(self.gene_reconstructor.weight, p=1) + # Initialize theta so softplus(theta) is roughly 1.0 + nn.init.normal_(self.theta_reconstructor.weight, std=0.01) + nn.init.constant_(self.theta_reconstructor.bias, 0.5) - def _normalize_coords(self, coords): - """ - Infer grid scaling from absolute coordinates. - (e.g., convert 224-pixel steps into unit steps). + def _build_interaction_mask(self, p, s, device): + """Build ``(P+S, P+S)`` boolean attention mask from ``self.interactions``. - Args: - coords: (B, N, 2) absolute pixel coordinates. + In PyTorch transformer convention, ``True`` means **blocked**. Returns: - (B, N, 2) grid-aligned indices. + torch.Tensor or None: Mask tensor, or ``None`` when all + interactions are enabled (no masking needed). """ - if coords is None: - return None - - # Zero-base - min_c = coords.min(dim=1, keepdim=True)[0] - zero_c = coords - min_c - - # Infer step per batch - # We use the median of non-zero adjacent differences to find step size - B, N, _ = zero_c.shape - if N < 2: - return torch.zeros_like(zero_c) - - steps = [] - for i in [0, 1]: # X and Y - c_sorted, _ = torch.sort(zero_c[..., i], dim=1) - diffs = c_sorted[:, 1:] - c_sorted[:, :-1] - diffs = diffs[diffs > 0] - step = diffs.median() if diffs.numel() > 0 else 1.0 - steps.append(step) - - grid_coords = torch.stack( - [ - torch.round(zero_c[..., 0] / steps[0]), - torch.round(zero_c[..., 1] / steps[1]), - ], - dim=-1, - ) - return grid_coords - - def _generate_spatial_mask(self, rel_coords): - """Generates distance-based masks for neighborhood attention. 
+ if self.interactions >= VALID_INTERACTIONS: + return None # everything enabled, skip masking + + total = p + s + # Start fully blocked + mask = torch.ones(total, total, dtype=torch.bool, device=device) + + if "p2p" in self.interactions: + mask[:p, :p] = False + if "p2h" in self.interactions: + mask[:p, p:] = False + if "h2p" in self.interactions: + mask[p:, :p] = False + if "h2h" in self.interactions: + mask[p:, p:] = False + + # Always allow self-attention (diagonal) + mask.fill_diagonal_(False) + return mask - Args: - rel_coords (torch.Tensor): Relative coordinates (B, S, 2). + def get_sparsity_loss(self): + """Computes L1 norm of reconstruction weights for sparsity regularization. Returns: - torch.Tensor: Boolean mask (B, S) where True means ignore. + torch.Tensor: L1 loss value. """ - if self.mask_radius is None: - return None - - # Calculate Euclidean distances from center (0, 0) - dists = torch.norm(rel_coords, dim=-1) - - # Mask elements that are beyond the specified radius - mask = dists > self.mask_radius - return mask + return torch.norm(self.gene_reconstructor.weight, p=1) - def forward(self, x, rel_coords=None, return_pathways=False): + def forward( + self, + x, + rel_coords=None, + return_pathways=False, + mask=None, + return_dense=False, + return_attention=False, + ): """Main inference path. Args: x (torch.Tensor): Image data or pre-computed features. - (B, 3, H, W): Single image patch. - - (B, S, 3, H, W): Image neighborhood. - (B, S, D): Pre-computed features. rel_coords (torch.Tensor, optional): Spatial relative coordinates. return_pathways (bool): Whether to return pathway activations. + mask (torch.Tensor, optional): Boolean padding mask for patches (B, S) where True = Padding. + return_dense (bool): If True, returns per-patch gene predictions instead of global predictions. + return_attention (bool): If True, returns attention maps from all layers. Returns: - torch.Tensor: Predicted gene counts (B, num_genes). 
- (Optional) torch.Tensor: Pathway activations (B, num_pathways). + torch.Tensor: Predicted gene counts (B, num_genes) or (B, N, num_genes) if return_dense. + (Optional) torch.Tensor: Pathway activations/scores. + (Optional) list[torch.Tensor]: Attention maps [L, B, H, T, T] if return_attention. """ - if x.dim() == 5: - # Neighborhood Mode: Extract features for all patches - b, s, c, h, w = x.shape - x_flat = x.view(b * s, c, h, w) - features_flat = self.backbone(x_flat) - features = features_flat.view(b, s, -1) - elif x.dim() == 4: - # Single Patch Mode + if x.dim() == 4: + # Single Patch Mode: (B, C, H, W) features = self.backbone(x).unsqueeze(1) b, s = features.shape[0], 1 else: - # Assumed pre-computed or pre-reshaped features (B, S, D) + # Pre-computed features: (B, S, D) features = x - if features.dim() == 2: - features = features.unsqueeze(1) b, s = features.shape[0], features.shape[1] # 1. Project features into latent interaction space memory = self.image_proj(features) # 1b. Inject Spatial Positional Encodings - if self.use_spatial_pe and rel_coords is not None: - # Add spatial PE to visual features + if self.use_spatial_pe: + if rel_coords is None: + raise ValueError( + "use_spatial_pe is True, but rel_coords was not provided. " + "Ensure the dataloader passes spatial coordinates." + ) pe = self.spatial_encoder(rel_coords) memory = memory + pe - # 1c. Local Patch Mixing (Conv) - if ( - hasattr(self, "early_mixer") - and self.early_mixer is not None - and rel_coords is not None - ): - # Enforce mixing of histology features before they attend to pathways - # Ensure coords are normalized to grid indices for LocalPatchMixer - grid_coords = self._normalize_coords(rel_coords) - memory = self.early_mixer(memory, grid_coords) - # 2. Retrieve learnable pathway tokens - tgt = self.pathway_tokenizer(b) + pathway_tokens = self.pathway_tokens.expand(b, -1, -1) # (B, P, D) + p = pathway_tokens.shape[1] - # 3. 
Process Interactions (Unified Early Fusion Path) - out = self.fusion_engine(p_tokens=tgt, h_tokens=memory) - - # 4. Project focused pathway tokens back to gene space - # We project each pathway token to a scalar activation, then use these - # activations to reconstruct gene expression levels (the bottleneck). - pathway_activations = self.pathway_activator(out).squeeze( - -1 - ) # (B, num_pathways) - gene_expression = self.gene_reconstructor(pathway_activations) # (B, num_genes) - - if return_pathways: - return gene_expression, pathway_activations - return gene_expression + # 3. Process Interactions (Standard ViT sequence: [Pathways, Patches]) + sequence = torch.cat([pathway_tokens, memory], dim=1) # (B, P + S, D) - def forward_dense(self, x, mask=None, return_pathways=False, coords=None): - """Predicts individualized gene counts for every patch in a slide. - - Optimized for dense prediction on whole slide images via global context. - - Args: - x (torch.Tensor): Precomputed features for N patches (B, N, D). - mask (torch.Tensor, optional): Boolean padding mask (B, N) where True = Padding. - return_pathways (bool): Whether to return pathway scores. - coords (torch.Tensor, optional): Absolute coordinates (B, N, 2) for PE. - - Returns: - torch.Tensor: Contextualized predictions for each patch (B, N, G). - (Optional) torch.Tensor: Pathway scores (B, N, P). - """ - if x.dim() == 2: - x = x.unsqueeze(0) + # Build attention mask from configured interactions + interaction_mask = self._build_interaction_mask(p, s, sequence.device) - b = x.shape[0] # Dynamic batch size + # If sparse/padded inputs, mask out padding so it doesn't attend + pad_mask = None + if mask is not None: + # Pad pathway tokens with False (don't ignore) + pad_mask = torch.cat( + [torch.zeros(b, p, dtype=torch.bool, device=mask.device), mask], dim=1 + ) - # 1. 
Project to latent space - memory = self.image_proj(x) + attentions = [] + if return_attention: + # Manual forward through fusion_engine layers to extract weights + # Standard nn.TransformerEncoder suppresses weights for performance. + x_layer = sequence + for layer in self.fusion_engine.layers: + # Multi-head attention bit + # We need to call the internal self_attn with need_weights=True + attn_output, attn_weights = layer.self_attn( + x_layer, + x_layer, + x_layer, + attn_mask=interaction_mask, + key_padding_mask=pad_mask, + need_weights=True, + ) + attentions.append(attn_weights) + + # Rest of the layer (as per nn.TransformerEncoderLayer) + if layer.norm_first: + x_layer = x_layer + layer._sa_block( + layer.norm1(x_layer), interaction_mask, pad_mask + ) + x_layer = x_layer + layer._ff_block(layer.norm2(x_layer)) + else: + x_layer = layer.norm1( + x_layer + layer._sa_block(x_layer, interaction_mask, pad_mask) + ) + x_layer = layer.norm2(x_layer + layer._ff_block(x_layer)) + out = x_layer + else: + out = self.fusion_engine( + sequence, mask=interaction_mask, src_key_padding_mask=pad_mask + ) - # 1b. Inject Spatial Positional Encodings (Global) - if self.use_spatial_pe and coords is not None: - pe = self.spatial_encoder(coords) - memory = memory + pe + # Extract focused pathway tokens + processed_pathway_tokens = out[:, :p, :] # (B, P, D) - # 1c. Local Patch Mixing (Conv) - if ( - hasattr(self, "early_mixer") - and self.early_mixer is not None - and coords is not None - ): - grid_coords = self._normalize_coords(coords) - memory = self.early_mixer(memory, grid_coords) - - # 2. Retrieve learnable pathway tokens (global context) - pathway_tokens = self.pathway_tokenizer(b) # (B, P, D) - - # 3. Perform dense interaction (Global Patch-to-Patch + Pathway context) - # The MultimodalFusion/NystromEncoder class handles the concatenation and quadrant masking. 
- all_tokens = self.fusion_engine( - p_tokens=pathway_tokens, h_tokens=memory, return_all_tokens=True - ) + # Extract processed patch tokens + processed_patch_tokens = out[:, p:, :] # (B, S, D) - # Sliced Histology tokens: indices [Np:] - np = pathway_tokens.shape[1] - context_features = all_tokens[:, np:, :] + # 5. Compute pathway scores via cosine similarity with learnable temperature + # L2-normalize both sets of tokens to produce cosine similarities in [-1, 1] + norm_pathway = F.normalize(processed_pathway_tokens, dim=-1) # (B, P, D) + temperature = self.log_temperature.exp() # scalar - # 3b. Local GNN Refinement - if self.late_refiner is not None and coords is not None: - # Explicit skip connection: Inject raw spatial/visual memory back into contextualized tokens - context_features = self.late_refiner( - context_features + memory, coords, mask=mask + if return_dense: + # Dense prediction: per-patch cosine similarity with pathway tokens + norm_patch = F.normalize(processed_patch_tokens, dim=-1) # (B, S, D) + # (B, S, D) @ (B, D, P) -> (B, S, P) + pathway_scores = ( + torch.matmul(norm_patch, norm_pathway.transpose(1, 2)) * temperature ) + else: + # Global prediction: pool patches first, then compute scores + global_patch_token = processed_patch_tokens.mean( + dim=1, keepdim=True + ) # (B, 1, D) + norm_global = F.normalize(global_patch_token, dim=-1) # (B, 1, D) + pathway_scores = ( + torch.matmul(norm_global, norm_pathway.transpose(1, 2)) * temperature + ) + pathway_scores = pathway_scores.squeeze(1) # (B, P) + + # Gene reconstruction (unified for both modes) + if self.output_mode == "zinb": + mu = F.softplus(self.gene_reconstructor(pathway_scores)) + 1e-6 + mu = torch.clamp(mu, max=1e5) + pi = torch.sigmoid(self.pi_reconstructor(pathway_scores)) + theta = F.softplus(self.theta_reconstructor(pathway_scores)) + 1e-6 + gene_expression = (pi, mu, theta) + else: + gene_expression = self.gene_reconstructor(pathway_scores) - # 4. 
Dense prediction head (Pathway Bottleneck enforcement) - # Project updated patch tokens onto pathway embeddings: (B, N, D) @ (B, D, P) -> (B, N, P) - # We scale by 1/sqrt(D) to maintain reasonable activation variance, as this is effectively attention. - pathway_scores = torch.matmul( - context_features, pathway_tokens.transpose(1, 2) - ) / (context_features.shape[-1] ** 0.5) - - # Reconstruct Genes: (B, N, P) @ (P, G) -> (B, N, G) - # Reuses the gene_reconstructor for biological consistency. + results = [gene_expression] if return_pathways: - return self.gene_reconstructor(pathway_scores), pathway_scores - return self.gene_reconstructor(pathway_scores) + results.append(pathway_scores) + if return_attention: + results.append(attentions) + + if len(results) == 1: + return results[0] + return tuple(results) diff --git a/src/spatial_transcript_former/predict.py b/src/spatial_transcript_former/predict.py index 1bfb65c..f8783cf 100644 --- a/src/spatial_transcript_former/predict.py +++ b/src/spatial_transcript_former/predict.py @@ -253,17 +253,18 @@ def plot_training_summary( if len(label) > 30: label = label[:27] + "..." 
- truth_raw = pathway_truth[:, pw_idx] - pred_raw = pathway_pred[:, pw_idx] + truth_vals = pathway_truth[:, pw_idx] + pred_vals = pathway_pred[:, pw_idx] - # Use absolute bounds instead of per-plot z-scored normalized bounds - vmin = min(truth_raw.min(), pred_raw.min()) - vmax = max(truth_raw.max(), pred_raw.max()) + # Both truth and pred are now in the same units (mean log1p expression + # of pathway member genes), so shared bounds give a fair comparison + vmin = min(truth_vals.min(), pred_vals.min()) + vmax = max(truth_vals.max(), pred_vals.max()) sc = None for col, vals, suffix in [ - (col_gt, truth_raw, "Truth"), - (col_pred, pred_raw, "Pred"), + (col_gt, truth_vals, "Truth"), + (col_pred, pred_vals, "Pred"), ]: ax = fig.add_subplot(inner[row, col]) ax.set_facecolor("#0d0d1a") @@ -331,11 +332,9 @@ def main(): parser.add_argument( "--n-neighbors", type=int, default=0, help="Number of spatial neighbors to use" ) - parser.add_argument( - "--use-nystrom", - action="store_true", - help="Use Nystrom attention for linear complexity", - ) + parser.add_argument("--token-dim", type=int, default=256) + parser.add_argument("--n-heads", type=int, default=4) + parser.add_argument("--n-layers", type=int, default=2) parser.add_argument( "--num-pathways", type=int, @@ -351,6 +350,13 @@ def main(): parser.add_argument( "--plot-pathways", action="store_true", help="Visualize pathway activations" ) + parser.add_argument( + "--loss", + type=str, + default="mse", + choices=["mse", "pcc", "mse_pcc", "zinb", "poisson", "logcosh"], + help="Loss function used for training (needed for model reconstruction)", + ) args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -372,9 +378,12 @@ def main(): elif args.model_type == "interaction": model = SpatialTranscriptFormer( num_genes=args.num_genes, - use_nystrom=args.use_nystrom, + token_dim=args.token_dim, + n_heads=args.n_heads, + n_layers=args.n_layers, num_pathways=args.num_pathways, 
backbone_name=args.backbone, + output_mode="zinb" if args.loss == "zinb" else "counts", ) state_dict = torch.load(args.model_path, map_location=device, weights_only=True) @@ -460,7 +469,7 @@ def main(): output = model( images, rel_coords=rel_coords, return_pathways=args.plot_pathways ) - if isinstance(output, tuple): + if isinstance(output, tuple) and args.plot_pathways: preds = output[0] pathways = output[1] all_pathways.append(pathways.cpu().numpy()) @@ -469,6 +478,10 @@ def main(): else: preds = model(images) + # Unpack ZINB tuple if generated + if isinstance(preds, tuple): + preds = preds[1] # Use mean component + all_preds.append(preds.cpu().numpy()) all_truth.append(targets.numpy()) diff --git a/src/spatial_transcript_former/train.py b/src/spatial_transcript_former/train.py index 3ad8783..b24c4f5 100644 --- a/src/spatial_transcript_former/train.py +++ b/src/spatial_transcript_former/train.py @@ -19,6 +19,7 @@ PCCLoss, CompositeLoss, MaskedMSELoss, + ZINBLoss, ) from spatial_transcript_former.training.engine import train_one_epoch, validate from spatial_transcript_former.training.experiment_logger import ExperimentLogger @@ -88,18 +89,14 @@ def setup_model(args, device): num_genes=args.num_genes, backbone_name=args.backbone, pretrained=args.pretrained, - use_nystrom=args.use_nystrom, - mask_radius=args.mask_radius, - masked_quadrants=args.masked_quadrants, + token_dim=args.token_dim, + n_heads=args.n_heads, + n_layers=args.n_layers, num_pathways=args.num_pathways, pathway_init=pathway_init, use_spatial_pe=args.use_spatial_pe, - early_mixer=( - None if args.early_mixer.lower() == "none" else args.early_mixer - ), - late_refiner=( - None if args.late_refiner.lower() == "none" else args.late_refiner - ), + output_mode="zinb" if args.loss == "zinb" else "counts", + interactions=getattr(args, "interactions", None), ) elif args.model == "attention_mil": from spatial_transcript_former.models.mil import AttentionMIL @@ -133,19 +130,34 @@ def setup_model(args, device): 
return model -def setup_criterion(args): - """Create loss function from CLI args.""" +def setup_criterion(args, pathway_init=None): + """Create loss function from CLI args. + + If ``pathway_init`` is provided and ``--pathway-loss-weight > 0``, + wraps the base criterion with :class:`AuxiliaryPathwayLoss`. + """ if args.loss == "pcc": - return PCCLoss() + base = PCCLoss() elif args.loss == "mse_pcc": - return CompositeLoss(alpha=args.pcc_weight) + base = CompositeLoss(alpha=args.pcc_weight) + elif args.loss == "zinb": + base = ZINBLoss() elif args.loss == "poisson": - return nn.PoissonNLLLoss(log_input=True) + base = nn.PoissonNLLLoss(log_input=True) elif args.loss == "logcosh": print("Using HuberLoss as proxy for LogCosh") - return nn.HuberLoss() + base = nn.HuberLoss() else: - return MaskedMSELoss() + base = MaskedMSELoss() + + pw_weight = getattr(args, "pathway_loss_weight", 0.0) + if pathway_init is not None and pw_weight > 0: + from spatial_transcript_former.training.losses import AuxiliaryPathwayLoss + + print(f"Wrapping criterion with AuxiliaryPathwayLoss (lambda={pw_weight})") + return AuxiliaryPathwayLoss(pathway_init, base, lambda_pathway=pw_weight) + + return base # --------------------------------------------------------------------------- @@ -248,7 +260,7 @@ def parse_args(): "--loss", type=str, default="mse", - choices=["mse", "pcc", "mse_pcc", "poisson", "logcosh"], + choices=["mse", "pcc", "mse_pcc", "zinb", "poisson", "logcosh"], ) parser.add_argument( "--pcc-weight", @@ -256,6 +268,12 @@ def parse_args(): default=1.0, help="Weight for PCC term in mse_pcc loss", ) + parser.add_argument( + "--pathway-loss-weight", + type=float, + default=0.0, + help="Weight for auxiliary pathway PCC loss (0 = disabled)", + ) # Model g = parser.add_argument_group("Model") @@ -269,26 +287,21 @@ def parse_args(): g.add_argument("--no-pretrained", action="store_false", dest="pretrained") g.set_defaults(pretrained=True) g.add_argument("--num-pathways", type=int, 
default=50) - g.add_argument("--use-nystrom", action="store_true") - g.add_argument("--mask-radius", type=float, default=None) + g.add_argument("--token-dim", type=int, default=256) + g.add_argument("--n-heads", type=int, default=4) + g.add_argument("--n-layers", type=int, default=2) g.add_argument( "--no-spatial-pe", action="store_false", dest="use_spatial_pe", - help="Disable Spatial Positional Encoding", - ) - g.set_defaults(use_spatial_pe=True) - g.add_argument( - "--early-mixer", - type=str, - default="conv", - help="Early spatial mixer ('conv' or 'none')", + help="Disable spatial positional encoding", ) + g.set_defaults(use_spatial_pe=False) g.add_argument( - "--late-refiner", - type=str, - default="none", - help="Late spatial refiner ('gnn' or 'none')", + "--interactions", + nargs="+", + default=None, + help="Attention interactions to enable: p2p, p2h, h2p, h2h (default: all)", ) # Training @@ -302,6 +315,7 @@ def parse_args(): "--lr", type=float, default=get_config("training.learning_rate", 1e-4) ) g.add_argument("--weight-decay", type=float, default=0.0) + g.add_argument("--warmup-epochs", type=int, default=10) g.add_argument("--sparsity-lambda", type=float, default=0.0) g.add_argument("--augment", action="store_true") g.add_argument("--use-amp", action="store_true") @@ -319,7 +333,6 @@ def parse_args(): g.add_argument("--use-global-context", action="store_true") g.add_argument("--global-context-size", type=int, default=128) g.add_argument("--compile-backend", type=str, default="inductor") - g.add_argument("--masked-quadrants", type=str, nargs="+", default=["H2H"]) g.add_argument("--plot-pathways", action="store_true") g.add_argument( "--weak-supervision", action="store_true", help="Bag-level training for MIL" @@ -373,12 +386,29 @@ def main(): # 2. 
Model, Loss, Optimizer model = setup_model(args, device) - criterion = setup_criterion(args) + # Pass pathway_init to criterion so AuxiliaryPathwayLoss can use it + pathway_init = getattr(model, "_pathway_init_matrix", None) + criterion = setup_criterion(args, pathway_init=pathway_init).to(device) optimizer = optim.Adam( model.parameters(), lr=args.lr, weight_decay=args.weight_decay ) + + # LR scheduler: cosine annealing with optional linear warmup + warmup_epochs = args.warmup_epochs + cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=args.epochs - warmup_epochs, eta_min=1e-6 + ) + + def lr_lambda(epoch): + if epoch < warmup_epochs: + return epoch / max(warmup_epochs, 1) + return 1.0 # cosine scheduler handles the rest + + warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + scaler = torch.amp.GradScaler("cuda") if args.use_amp else None print(f"Loss: {criterion.__class__.__name__}") + print(f"LR schedule: {warmup_epochs}-epoch warmup → cosine annealing to 1e-6") # 3. 
Output & Logger os.makedirs(args.output_dir, exist_ok=True) @@ -418,14 +448,28 @@ def main(): ) val_loss = val_metrics["val_loss"] - print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}") + print( + f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.2e}" + ) + + # Step LR scheduler + if epoch < warmup_epochs: + warmup_scheduler.step() + else: + cosine_scheduler.step() # Log epoch - epoch_row = {"train_loss": train_loss, "val_loss": val_loss} + epoch_row = { + "train_loss": train_loss, + "val_loss": val_loss, + "lr": optimizer.param_groups[0]["lr"], + } if val_metrics.get("val_mae") is not None: epoch_row["val_mae"] = round(val_metrics["val_mae"], 4) if val_metrics.get("val_pcc") is not None: epoch_row["val_pcc"] = round(val_metrics["val_pcc"], 4) + if val_metrics.get("pred_variance") is not None: + epoch_row["pred_variance"] = round(val_metrics["pred_variance"], 6) if val_metrics.get("attn_correlation") is not None: epoch_row["attn_correlation"] = round(val_metrics["attn_correlation"], 4) logger.log_epoch(epoch + 1, epoch_row) diff --git a/src/spatial_transcript_former/training/engine.py b/src/spatial_transcript_former/training/engine.py index ed26c79..7319b48 100644 --- a/src/spatial_transcript_former/training/engine.py +++ b/src/spatial_transcript_former/training/engine.py @@ -9,6 +9,7 @@ import torch.nn as nn from tqdm import tqdm from spatial_transcript_former.models import SpatialTranscriptFormer +from spatial_transcript_former.training.losses import AuxiliaryPathwayLoss def _optimizer_step( @@ -74,11 +75,29 @@ def train_one_epoch( ) with torch.amp.autocast("cuda", enabled=scaler is not None): - if hasattr(model, "forward_dense") and not getattr( + if isinstance(model, SpatialTranscriptFormer) and not getattr( model, "weak_supervision", False ): - preds = model.forward_dense(feats, mask=mask, coords=coords) - loss = criterion(preds, genes, mask=mask) + # Request pathway scores if criterion can use them + 
needs_pathways = isinstance(criterion, AuxiliaryPathwayLoss) + output = model( + feats, + return_dense=True, + mask=mask, + rel_coords=coords, + return_pathways=needs_pathways, + ) + if needs_pathways: + preds, pathway_preds = output + loss = criterion( + preds, + genes, + mask=mask, + pathway_preds=pathway_preds, + ) + else: + preds = output + loss = criterion(preds, genes, mask=mask) else: preds = model(feats) bag_target = _compute_bag_target(genes, mask) @@ -136,6 +155,7 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) running_loss = 0.0 running_mae = 0.0 pcc_list = [] + pred_var_list = [] attn_correlations = [] with torch.no_grad(): @@ -157,10 +177,21 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) attn = None if whole_slide: - if hasattr(model, "forward_dense") and not getattr( + if isinstance(model, SpatialTranscriptFormer) and not getattr( model, "weak_supervision", False ): - outputs = model.forward_dense(feats, mask=mask, coords=coords) + needs_pathways = isinstance(criterion, AuxiliaryPathwayLoss) + output = model( + feats, + return_dense=True, + mask=mask, + rel_coords=coords, + return_pathways=needs_pathways, + ) + if needs_pathways: + outputs, pathway_preds = output + else: + outputs = output targets = genes else: # MIL models: extract attention if supported @@ -179,33 +210,50 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) else: outputs = model(images) - loss = ( - criterion(outputs, targets, mask=mask) - if whole_slide - and hasattr(model, "forward_dense") + # Compute loss, passing pathway_preds if available + if ( + whole_slide + and isinstance(model, SpatialTranscriptFormer) and not getattr(model, "weak_supervision", False) - else criterion(outputs, targets) - ) + ): + if isinstance(criterion, AuxiliaryPathwayLoss): + loss = criterion( + outputs, + targets, + mask=mask, + pathway_preds=pathway_preds, + ) + else: + loss = criterion(outputs, 
targets, mask=mask) + else: + loss = criterion(outputs, targets) # --- Interpretability Metrics (MAE & PCC) --- - mae_diff = torch.abs(outputs - targets) + # Since ZINB loss outputs a tuple (pi, mu, theta), we only use mu (index 1) for evaluations against truth. + eval_preds = outputs[1] if isinstance(outputs, tuple) else outputs + mae_diff = torch.abs(eval_preds - targets) if ( whole_slide - and hasattr(model, "forward_dense") + and isinstance(model, SpatialTranscriptFormer) and not getattr(model, "weak_supervision", False) and mask is not None ): valid_mask = ~mask.unsqueeze(-1).expand_as(mae_diff) mae_val = (mae_diff * valid_mask.float()).sum() / valid_mask.sum() + else: + mae_val = mae_diff.mean() - if torch.isfinite(outputs).all() and torch.isfinite(targets).all(): + if ( + torch.isfinite(eval_preds).all() + and torch.isfinite(targets).all() + ): # Calculate Spatial PCC (across spots N, for each gene G independently) # outputs/targets are (B, N, G) for whole_slide or (B, G) for patch if whole_slide: # Iterate over batches to correlate spatially for each slide - B = outputs.shape[0] + B = eval_preds.shape[0] for b_idx in range(B): - p_slide = outputs[b_idx] # (N, G) + p_slide = eval_preds[b_idx] # (N, G) t_slide = targets[b_idx] # (N, G) valid_idx = ~mask[b_idx] @@ -231,7 +279,7 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) pcc_list.append(valid_corrs.mean().item()) else: # Patch level (B, G). 
Correlate across the batch B (which is spatial patches) - vx = outputs - outputs.mean(dim=0, keepdim=True) + vx = eval_preds - eval_preds.mean(dim=0, keepdim=True) vy = targets - targets.mean(dim=0, keepdim=True) num = torch.sum(vx * vy, dim=0) den = torch.sqrt( @@ -261,6 +309,23 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) running_loss += loss.item() running_mae += mae_val.item() + # Track prediction variance (collapse detector) + with torch.no_grad(): + if ( + whole_slide + and mask is not None + and isinstance(model, SpatialTranscriptFormer) + and not getattr(model, "weak_supervision", False) + ): + for b in range(eval_preds.shape[0]): + valid = ~mask[b] + if valid.sum() >= 2: + pred_var_list.append( + eval_preds[b, valid].var(dim=0).mean().item() + ) + else: + pred_var_list.append(eval_preds.var(dim=0).mean().item()) + avg_loss = running_loss / len(loader) avg_mae = running_mae / len(loader) avg_pcc = sum(pcc_list) / len(pcc_list) if pcc_list else None @@ -268,8 +333,12 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) sum(attn_correlations) / len(attn_correlations) if attn_correlations else None ) + avg_pred_var = sum(pred_var_list) / len(pred_var_list) if pred_var_list else None + if avg_pcc is not None: print(f"Validation MAE: {avg_mae:.4f} | PCC: {avg_pcc:.4f}") + if avg_pred_var is not None: + print(f"Prediction Variance: {avg_pred_var:.6f}") if avg_corr is not None: print(f"Spatial Attention Correlation: {avg_corr:.4f}") @@ -277,5 +346,6 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) "val_loss": avg_loss, "val_mae": avg_mae, "val_pcc": avg_pcc, + "pred_variance": avg_pred_var, "attn_correlation": avg_corr, } diff --git a/src/spatial_transcript_former/training/losses.py b/src/spatial_transcript_former/training/losses.py index ffdfae0..e9a3bc2 100644 --- a/src/spatial_transcript_former/training/losses.py +++ 
b/src/spatial_transcript_former/training/losses.py @@ -37,25 +37,69 @@ def forward(self, preds, target, mask=None): Returns: Scalar loss = 1 - mean(PCC). """ - if preds.dim() == 3: - B, N, G = preds.shape - preds = preds.reshape(-1, G) # (B*N, G) - target = target.reshape(-1, G) + if preds.dim() == 2: + preds = preds.unsqueeze(1) # (B, 1, G) + target = target.unsqueeze(1) # (B, 1, G) + if mask is not None: + mask = mask.unsqueeze(1) # (B, 1) + + B, N, G = preds.shape + # If N == 1 (e.g., standard patch-wise prediction without context), + # PCC across a spatial dimension of 1 is undefined (variance is 0). + # We fallback to batch-wise correlation in this specific edge case. + if N == 1: + preds = preds.squeeze(1) # (B, G) + target = target.squeeze(1) # (B, G) if mask is not None: - valid = ~mask.reshape(-1) # (B*N,) + valid = ~mask.squeeze(1) # (B) preds = preds[valid] target = target[valid] - # Centre per-spot - vx = preds - preds.mean(dim=0, keepdim=True) - vy = target - target.mean(dim=0, keepdim=True) + if preds.shape[0] < 2: + return torch.tensor(0.0, device=preds.device, requires_grad=True) + + vx = preds - preds.mean(dim=0, keepdim=True) + vy = target - target.mean(dim=0, keepdim=True) + cost = torch.sum(vx * vy, dim=0) / ( + torch.sqrt(torch.sum(vx**2, dim=0) + self.eps) + * torch.sqrt(torch.sum(vy**2, dim=0) + self.eps) + ) + return 1 - cost.mean() + + # 1. Masking: Zero out padded positions so they don't contribute to sums + if mask is not None: + valid = ~mask.unsqueeze(-1) # (B, N, 1) + preds = preds * valid.float() + target = target * valid.float() + # valid_counts: (B, 1, 1) to enable broadcasting across N (dim 1) and G (dim 2) + valid_counts = valid.sum(dim=1, keepdim=True).clamp(min=1.0) + else: + valid_counts = torch.tensor(N, dtype=torch.float32, device=preds.device) - # Correlation - cost = torch.sum(vx * vy, dim=0) / ( - torch.sqrt(torch.sum(vx**2, dim=0) + self.eps) - * torch.sqrt(torch.sum(vy**2, dim=0) + self.eps) - ) + # 2. 
Centre per-slide (over spatial dimension N) + # Calculate mean using valid counts to avoid skew from zeros + pred_means = preds.sum(dim=1, keepdim=True) / valid_counts + target_means = target.sum(dim=1, keepdim=True) / valid_counts + + vx = preds - pred_means + vy = target - target_means + + if mask is not None: + vx = vx * valid.float() + vy = vy * valid.float() + + # 3. Correlation (covariance / (std_x * std_y)) per slide, per gene + # sum over spatial dimension N -> (B, G) + cov = torch.sum(vx * vy, dim=1) + var_x = torch.sum(vx**2, dim=1) + var_y = torch.sum(vy**2, dim=1) + + cost = cov / ( + torch.sqrt(var_x + self.eps) * torch.sqrt(var_y + self.eps) + ) # (B, G) + + # Average across genes, then average across batch return 1 - cost.mean() @@ -143,3 +187,151 @@ def forward(self, preds, target, mask=None): return (loss * valid.float()).sum() / valid.sum() return loss.mean() + + +class ZINBLoss(nn.Module): + """ + Zero-Inflated Negative Binomial (ZINB) Loss. + + Designed for highly dispersed, zero-inflated count data (like raw RNA-seq). + Requires the model to output three parameters per gene: + - pi: probability of zero inflation (dropout) + - mu: mean of the negative binomial distribution + - theta: inverse dispersion parameter + """ + + def __init__(self, eps=1e-8): + super().__init__() + self.eps = eps + + def forward(self, preds, target, mask=None): + """ + Args: + preds: Tuple of (pi, mu, theta), each (B, G) or (B, N, G). + pi should be [0, 1] (e.g. via Sigmoid). + mu, theta should be > 0 (e.g. via Softplus or Exp). + target: Raw integer counts (B, G) or (B, N, G). + mask: (B, N) boolean, True = padded (ignore). Optional. + + Returns: + Scalar negative log-likelihood over valid positions. 
+ """ + pi, mu, theta = preds + + # Add numerical stability to theta and mu + theta = torch.clamp(theta, min=self.eps, max=1e6) + mu = torch.clamp(mu, min=self.eps, max=1e6) + + # Ensure target is valid for ZINB (no negatives, we expect counts) + target = torch.clamp(target, min=0) + + # 1. Negative Binomial Probability + # NB(y; mu, theta) = Gamma(y+theta)/(Gamma(theta)*y!) * (theta/(theta+mu))^theta * (mu/(theta+mu))^y + # Log version for stability using lgamma: + # ln(Gamma(target + theta)) - ln(Gamma(theta)) - ln(Gamma(target + 1)) + # + theta * ln(theta / (theta + mu)) + target * ln(mu / (theta + mu)) + + t1 = ( + torch.lgamma(target + theta) + - torch.lgamma(theta) + - torch.lgamma(target + 1) + ) + t2 = theta * (torch.log(theta + self.eps) - torch.log(theta + mu + self.eps)) + t3 = target * (torch.log(mu + self.eps) - torch.log(theta + mu + self.eps)) + + nb_log_prob = t1 + t2 + t3 + + # 2. Zero Inflation Mask + is_zero = (target == 0).float() + + # 3. ZINB Log Likelihood + # 3a. If target == 0: ln(pi + (1-pi) * NB(0; mu, theta)) + # NB(0; mu, theta) = (theta / (theta + mu))^theta + # Log space for stability: theta * (log(theta) - log(theta + mu)) + nb_zero_log_prob = theta * ( + torch.log(theta + self.eps) - torch.log(theta + mu + self.eps) + ) + + # zero_case_prob = pi + (1 - pi) * exp(nb_zero_log_prob) + # We need ln(pi + (1-pi)*exp(nb_zero_log_prob)). + # clamp heavily before taking log to prevent NaNs when pi -> 0 and exp(nb) -> 0 + zero_case_prob = pi + (1 - pi) * torch.exp(nb_zero_log_prob) + zero_case_prob = torch.clamp(zero_case_prob, min=self.eps, max=1.0) + zero_case_log_prob = torch.log(zero_case_prob) + + # 3b. 
If target > 0: ln((1-pi) * NB(target; mu, theta)) + # = ln(1-pi) + ln(NB) + non_zero_case_log_prob = torch.log(1 - pi + self.eps) + nb_log_prob + + # Combine + log_likelihood = ( + is_zero * zero_case_log_prob + (1 - is_zero) * non_zero_case_log_prob + ) + + if mask is not None and pi.dim() == 3: + # Expand mask to gene dimension: (B, N) -> (B, N, G) + valid = ~mask.unsqueeze(-1).expand_as(log_likelihood) + return -(log_likelihood * valid.float()).sum() / valid.sum() + + return -log_likelihood.mean() + + +class AuxiliaryPathwayLoss(nn.Module): + """Wrapper combining gene-level loss with PCC-based pathway auxiliary loss. + + Directly supervises the pathway bottleneck scores using PCC against + pathway ground truth computed from gene expression via MSigDB membership. + This provides a direct gradient signal to the pathway tokens, preventing + bottleneck collapse and ensuring each pathway learns its correct spatial + activation pattern. + + The total loss is:: + + L = L_gene + lambda * (1 - PCC(pathway_scores, target_pathways)) + + where ``target_pathways = target_genes @ M.T`` and M is the MSigDB + membership matrix. + """ + + def __init__(self, pathway_matrix, gene_criterion, lambda_pathway=0.5): + """ + Args: + pathway_matrix (torch.Tensor): Binary MSigDB membership matrix + of shape ``(P, G)`` where ``P`` is the number of pathways and + ``G`` is the number of genes. + gene_criterion (nn.Module): Base loss for gene prediction + (e.g. ``MaskedMSELoss``, ``CompositeLoss``). + lambda_pathway (float): Weight for the auxiliary pathway loss. + """ + super().__init__() + self.register_buffer("pathway_matrix", pathway_matrix.float()) + self.gene_criterion = gene_criterion + self.lambda_pathway = lambda_pathway + self.pcc = PCCLoss() + + def forward(self, gene_preds, target_genes, mask=None, pathway_preds=None): + """ + Args: + gene_preds: (B, G) or (B, N, G) predicted gene expression. + target_genes: (B, G) or (B, N, G) ground truth gene expression. 
+ mask: (B, N) boolean, True = padded (ignore). Optional. + pathway_preds: (B, P) or (B, N, P) predicted pathway scores. + If None, only gene loss is computed (graceful fallback). + + Returns: + Scalar total loss. + """ + gene_loss = self.gene_criterion(gene_preds, target_genes, mask=mask) + + if pathway_preds is None or self.lambda_pathway == 0: + return gene_loss + + # Compute pathway ground truth from gene expression + # target_genes: (B, [N,] G), pathway_matrix: (P, G) + # result: (B, [N,] P) + with torch.no_grad(): + target_pathways = torch.matmul(target_genes, self.pathway_matrix.T) + + pathway_loss = self.pcc(pathway_preds, target_pathways, mask=mask) + + return gene_loss + self.lambda_pathway * pathway_loss diff --git a/src/spatial_transcript_former/visualization.py b/src/spatial_transcript_former/visualization.py index e7b7a39..1e3b367 100644 --- a/src/spatial_transcript_former/visualization.py +++ b/src/spatial_transcript_former/visualization.py @@ -152,23 +152,12 @@ def run_inference_plot(model, args, sample_id, epoch, device): h5_path = os.path.join(patches_dir, f"{sample_id}.h5") h5ad_path = os.path.join(st_dir, f"{sample_id}.h5ad") - with h5py.File(h5_path, "r") as f: - patch_barcodes = f["barcode"][:].flatten() - coords = f["coords"][:] - # Load global genes try: common_gene_names = load_global_genes(args.data_dir, args.num_genes) except Exception: common_gene_names = None - gene_matrix, mask, gene_names = load_gene_expression_matrix( - h5ad_path, - patch_barcodes, - selected_gene_names=common_gene_names, - num_genes=args.num_genes, - ) - # Run inference preds = [] pathways_list = [] @@ -194,18 +183,53 @@ def run_inference_plot(model, args, sample_id, epoch, device): selected_gene_names=common_gene_names, n_neighbors=args.n_neighbors, whole_slide_mode=args.whole_slide, + log1p=log_transform, ) if args.whole_slide: - feats, _, _ = ds[0] + feats, gene_targets, slide_coords = ds[0] feats = feats.unsqueeze(0).to(device) - if hasattr(model, 
"forward_dense"): - output = model.forward_dense(feats, return_pathways=True) + slide_coords = slide_coords.unsqueeze(0).to(device) + if isinstance(model, SpatialTranscriptFormer): + output = model( + feats, + return_dense=True, + rel_coords=slide_coords, + return_pathways=True, + ) if isinstance(output, tuple): - preds.append(output[0].detach().cpu().squeeze(0)) + out_preds = output[0] + if isinstance(out_preds, tuple): + out_preds = out_preds[1] + preds.append(out_preds.detach().cpu().squeeze(0)) pathways_list.append(output[1].detach().cpu().squeeze(0)) else: preds.append(output.detach().cpu().squeeze(0)) + + # Use raw pixel coords from the .pt file for histology overlay + # These are guaranteed to be in the same order as the features + saved_data = torch.load( + feature_path, map_location="cpu", weights_only=True + ) + raw_coords = saved_data["coords"] # (N, 2) + raw_barcodes = saved_data["barcodes"] + del saved_data + + # Compute the same mask the dataset used to filter + _, pt_mask, gene_names = load_gene_expression_matrix( + h5ad_path, + raw_barcodes, + selected_gene_names=common_gene_names, + num_genes=args.num_genes, + ) + pt_mask_bool = np.array(pt_mask, dtype=bool) + coord_subset = raw_coords[pt_mask_bool].numpy() + + # Pathway truth from the dataset's aligned gene matrix + gene_truth = gene_targets.numpy() + pathway_truth, pathway_names = _compute_pathway_truth( + gene_truth, gene_names, args=args + ) else: dl = DataLoader(ds, batch_size=32, shuffle=False) for feats, _, rel_coords_batch in dl: @@ -217,13 +241,43 @@ def run_inference_plot(model, args, sample_id, epoch, device): ) if isinstance(output, tuple): pathways_list.append(output[1].cpu()) - preds.append(output[0].cpu()) + out_preds = output[0] + if isinstance(out_preds, tuple): + out_preds = out_preds[1] + preds.append(out_preds.cpu()) else: preds.append(output.cpu()) else: preds.append(model(feats.to(device)).cpu()) + + # Non-whole-slide: use h5 file coords (same source as DataLoader) + with 
h5py.File(h5_path, "r") as f: + patch_barcodes = f["barcode"][:].flatten() + h5_coords = f["coords"][:] + gene_matrix, mask, gene_names = load_gene_expression_matrix( + h5ad_path, + patch_barcodes, + selected_gene_names=common_gene_names, + num_genes=args.num_genes, + ) + coord_mask = np.array(mask, dtype=bool) + coord_subset = h5_coords[coord_mask] + gene_truth = np.log1p(gene_matrix) if log_transform else gene_matrix + pathway_truth, pathway_names = _compute_pathway_truth( + gene_truth, gene_names, args=args + ) else: - coord_subset = coords[mask] + with h5py.File(h5_path, "r") as f: + patch_barcodes = f["barcode"][:].flatten() + h5_coords = f["coords"][:] + gene_matrix, mask, gene_names = load_gene_expression_matrix( + h5ad_path, + patch_barcodes, + selected_gene_names=common_gene_names, + num_genes=args.num_genes, + ) + coord_mask = np.array(mask, dtype=bool) + coord_subset = h5_coords[coord_mask] ds = HEST_Dataset( h5_path, coord_subset, gene_matrix, indices=np.where(mask)[0] ) @@ -237,28 +291,32 @@ def run_inference_plot(model, args, sample_id, epoch, device): ) if isinstance(output, tuple): pathways_list.append(output[1].cpu()) - preds.append(output[0].cpu()) + out_preds = output[0] + if isinstance(out_preds, tuple): + out_preds = out_preds[1] + preds.append(out_preds.cpu()) else: preds.append(output.cpu()) else: preds.append(model(imgs.to(device)).cpu()) + gene_truth = np.log1p(gene_matrix) if log_transform else gene_matrix + pathway_truth, pathway_names = _compute_pathway_truth( + gene_truth, gene_names, args=args + ) - if not preds or not pathways_list: - print("Warning: No predictions/pathways generated. Skipping plot.") + if not preds: + print("Warning: No predictions generated. 
Skipping plot.") return - pathways = torch.cat(pathways_list, dim=0).numpy() - coord_mask = np.array(mask, dtype=bool) - coord_subset = coords[coord_mask] - - # Compute pathway ground truth from MSigDB membership (fixed across epochs) - gene_truth = np.log1p(gene_matrix) if log_transform else gene_matrix - pathway_truth, pathway_names = _compute_pathway_truth( - gene_truth, gene_names, args=args + # Compute pathway activations from gene predictions (same method as truth) + # Both truth and pred are now: mean gene expression of pathway members + gene_preds_np = torch.cat(preds, dim=0).numpy() + pathway_pred, _ = _compute_pathway_truth( + gene_preds_np, gene_names, args=args ) - if pathway_truth is None: - print("Warning: Could not compute pathway truth. Skipping plot.") + if pathway_truth is None or pathway_pred is None: + print("Warning: Could not compute pathway truth/pred. Skipping plot.") return # Load histology image @@ -270,7 +328,7 @@ def run_inference_plot(model, args, sample_id, epoch, device): ) plot_training_summary( coord_subset, - pathways, + pathway_pred, pathway_truth, pathway_names, sample_id=sample_id, diff --git a/tests/test_bottleneck_arch.py b/tests/test_bottleneck_arch.py index a2effba..f2eea5e 100644 --- a/tests/test_bottleneck_arch.py +++ b/tests/test_bottleneck_arch.py @@ -16,9 +16,11 @@ def test_interaction_output_shape(): # Dummy input (Batch, Channel, Height, Width) x = torch.randn(4, 3, 224, 224) + # Single patch => S=1 + rel_coords = torch.randn(4, 1, 2) # Forward pass - output = model(x) + output = model(x, rel_coords=rel_coords) # Verify shape (Batch, num_genes) assert output.shape == ( @@ -32,8 +34,9 @@ def test_interaction_gradient_flow(): num_genes = 1000 model = SpatialTranscriptFormer(num_genes=num_genes) x = torch.randn(2, 3, 224, 224) + rel_coords = torch.randn(2, 1, 2) - output = model(x) + output = model(x, rel_coords=rel_coords) loss = output.sum() loss.backward() diff --git a/tests/test_checkpoint.py 
b/tests/test_checkpoint.py index 1e86a27..08fe63b 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -27,22 +27,7 @@ def small_model(): num_pathways=10, token_dim=64, n_heads=4, - n_layers=1, - use_nystrom=False, - ) - - -@pytest.fixture -def small_nystrom_model(): - """A small SpatialTranscriptFormer with Nystrom attention.""" - return SpatialTranscriptFormer( - num_genes=100, - num_pathways=10, - token_dim=64, - n_heads=4, - n_layers=1, - use_nystrom=True, - num_landmarks=32, + n_layers=2, ) @@ -85,8 +70,7 @@ def test_save_load_preserves_weights(self, small_model, checkpoint_dir): num_pathways=10, token_dim=64, n_heads=4, - n_layers=1, - use_nystrom=False, + n_layers=2, ) fresh_optimizer = optim.Adam(fresh_model.parameters(), lr=1e-4) @@ -125,8 +109,7 @@ def test_save_load_preserves_scaler(self, small_model, checkpoint_dir): num_pathways=10, token_dim=64, n_heads=4, - n_layers=1, - use_nystrom=False, + n_layers=2, ) fresh_optimizer = optim.Adam(fresh_model.parameters(), lr=1e-4) fresh_scaler = torch.amp.GradScaler("cuda") @@ -151,75 +134,3 @@ def test_no_checkpoint_starts_fresh(self, small_model, checkpoint_dir): ) assert start_epoch == 0 assert best_val == float("inf") - - -# --------------------------------------------------------------------------- -# Architecture Mismatch -# --------------------------------------------------------------------------- - - -class TestArchitectureMismatch: - def test_nystrom_checkpoint_fails_on_standard( - self, small_nystrom_model, checkpoint_dir - ): - """Nystrom checkpoint should fail to load into standard model.""" - optimizer = optim.Adam(small_nystrom_model.parameters(), lr=1e-4) - save_checkpoint( - small_nystrom_model, - optimizer, - None, - epoch=5, - best_val_loss=0.3, - output_dir=checkpoint_dir, - model_name="interaction", - ) - - # Try loading into non-Nystrom model - standard_model = SpatialTranscriptFormer( - num_genes=100, - num_pathways=10, - token_dim=64, - n_heads=4, - n_layers=1, - 
use_nystrom=False, - ) - standard_optimizer = optim.Adam(standard_model.parameters(), lr=1e-4) - - with pytest.raises(RuntimeError): - load_checkpoint( - standard_model, - standard_optimizer, - None, - checkpoint_dir, - "interaction", - "cpu", - ) - - def test_matching_architecture_loads(self, small_nystrom_model, checkpoint_dir): - """Nystrom checkpoint should load into matching Nystrom model.""" - optimizer = optim.Adam(small_nystrom_model.parameters(), lr=1e-4) - save_checkpoint( - small_nystrom_model, - optimizer, - None, - epoch=5, - best_val_loss=0.3, - output_dir=checkpoint_dir, - model_name="interaction", - ) - - fresh = SpatialTranscriptFormer( - num_genes=100, - num_pathways=10, - token_dim=64, - n_heads=4, - n_layers=1, - use_nystrom=True, - num_landmarks=32, - ) - fresh_opt = optim.Adam(fresh.parameters(), lr=1e-4) - - start_epoch, _ = load_checkpoint( - fresh, fresh_opt, None, checkpoint_dir, "interaction", "cpu" - ) - assert start_epoch == 6 # Resumed correctly diff --git a/tests/test_interactions.py b/tests/test_interactions.py new file mode 100644 index 0000000..6847538 --- /dev/null +++ b/tests/test_interactions.py @@ -0,0 +1,176 @@ +import torch +import torch.nn as nn +import sys +import os +import numpy as np + +# Add src to path +sys.path.append(os.path.abspath("src")) + +from spatial_transcript_former.models.interaction import SpatialTranscriptFormer + + +def test_mask_logic(): + print("Testing Mask Logic...") + model = SpatialTranscriptFormer( + num_genes=100, num_pathways=10, interactions=["p2p", "p2h"] + ) + + # p=10, s=20 + mask = model._build_interaction_mask(10, 20, "cpu") + + # Check p2p (allowed) + assert mask[0, 1] == False, "p2p should be allowed" + # Check p2h (allowed) + assert mask[0, 15] == False, "p2h should be allowed" + # Check h2p (blocked) + assert mask[15, 0] == True, "h2p should be blocked" + # Check h2h (blocked) + assert mask[15, 16] == True, "h2h should be blocked" + + print("Mask logic test passed!") + + +def 
test_connectivity(): + print("Testing Patch-to-Patch Connectivity (h2h)...") + # Enable all interactions including h2h + model = SpatialTranscriptFormer( + num_genes=100, + num_pathways=10, + n_layers=2, + interactions=["p2p", "p2h", "h2p", "h2h"], + use_spatial_pe=False, # Disable PE to verify raw attention connectivity + pretrained=False, + ) + model.eval() + + # dummy features (B=1, S=2, D=2048) + feats = torch.randn(1, 2, 2048, requires_grad=True) + coords = torch.randn(1, 2, 2) + + # We want to see if the output for patch 0 depends on the input of patch 1 + # We use return_dense=True to get per-patch gene predictions + output = model(feats, rel_coords=coords, return_dense=True) # (1, 2, 100) + + # Loss on patch 0 output + loss = output[0, 0].sum() + loss.backward() + + # Check if gradient flows to patch 1 input + grad_patch_1 = feats.grad[0, 1].norm() + print(f"Gradient at Patch 1 from Patch 0 output: {grad_patch_1.item():.6e}") + + assert ( + grad_patch_1 > 0 + ), "Patch 0 output should depend on Patch 1 input when h2h is enabled" + + # Now try with h2h disabled + print("Testing Connectivity with h2h disabled...") + model_no_h2h = SpatialTranscriptFormer( + num_genes=100, + num_pathways=10, + n_layers=2, + interactions=["p2p", "p2h", "h2p"], + use_spatial_pe=False, + pretrained=False, + ) + model_no_h2h.eval() + + feats_2 = torch.randn(1, 2, 2048, requires_grad=True) + output_2 = model_no_h2h(feats_2, rel_coords=coords, return_dense=True) + + loss_2 = output_2[0, 0].sum() + loss_2.backward() + + grad_patch_1_no_h2h = feats_2.grad[0, 1].norm() + print( + f"Gradient at Patch 1 from Patch 0 output (no h2h): {grad_patch_1_no_h2h.item():.6e}" + ) + + # It should still be non-zero because patches interact via pathways [Patch 1 -> Pathway -> Patch 0] + assert ( + grad_patch_1_no_h2h > 0 + ), "Patch 0 should still interact with Patch 1 via pathways even if h2h is disabled" + + # To truly see zero interaction, block pathways too + print("Testing ZERO Connectivity 
(only p2p enabled)...") + model_isolated = SpatialTranscriptFormer( + num_genes=100, + num_pathways=10, + n_layers=2, + interactions=["p2p"], + use_spatial_pe=False, + pretrained=False, + ) + model_isolated.eval() + feats_3 = torch.randn(1, 2, 2048, requires_grad=True) + output_3 = model_isolated(feats_3, rel_coords=coords, return_dense=True) + loss_3 = output_3[0, 0].sum() + loss_3.backward() + grad_patch_1_isolated = feats_3.grad[0, 1].norm() + print(f"Gradient at Patch 1 (fully isolated): {grad_patch_1_isolated.item():.6e}") + assert ( + grad_patch_1_isolated < 1e-10 + ), "Patch 0 should NOT depend on Patch 1 when only p2p is enabled" + + print("Connectivity tests passed!") + + +def test_attention_extraction(): + print("Testing Attention Extraction...") + p, s = 10, 20 + model = SpatialTranscriptFormer( + num_genes=100, + num_pathways=p, + n_layers=2, + interactions=["p2p", "p2h"], # Block h2p, h2h + pretrained=False, + ) + model.eval() + + feats = torch.randn(1, s, 2048) + coords = torch.randn(1, s, 2) + + # Forward with attention + _, attentions = model(feats, rel_coords=coords, return_attention=True) + + # attentions is list of weights [layers] + for i, attn in enumerate(attentions): + print(f"Testing Layer {i}...") + # attn is (B, T, T) + assert attn.shape == (1, p + s, p + s) + + # We expect blocked regions to have 0 attention + h2p_region = attn[0, p:, :p] + h2h_region = attn[0, p:, p:] + + # For h2h, we must ignore diagonal + h2h_off_diag = h2h_region.clone() + h2h_off_diag.fill_diagonal_(0) + + print(f"Layer {i} h2p attention max: {h2p_region.max().item():.2e}") + print(f"Layer {i} h2h off-diag attention max: {h2h_off_diag.max().item():.2e}") + + assert ( + h2p_region.max() < 1e-10 + ), f"Layer {i} h2p attention should be zero when blocked" + assert ( + h2h_off_diag.max() < 1e-10 + ), f"Layer {i} h2h attention should be zero when blocked" + + # Check that allowed regions have non-zero attention + p2p_region = attn[0, :p, :p] + p2h_region = attn[0, :p, 
p:] + print(f"Layer {i} p2p attention max: {p2p_region.max().item():.2e}") + print(f"Layer {i} p2h attention max: {p2h_region.max().item():.2e}") + + assert p2p_region.max() > 0, f"Layer {i} p2p attention should be non-zero" + assert p2h_region.max() > 0, f"Layer {i} p2h attention should be non-zero" + + print("Attention extraction test passed!") + + +if __name__ == "__main__": + test_mask_logic() + test_connectivity() + test_attention_extraction() diff --git a/tests/test_local_patch_mixer.py b/tests/test_local_patch_mixer.py deleted file mode 100644 index 20a5767..0000000 --- a/tests/test_local_patch_mixer.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np -import pytest -from spatial_transcript_former.models.interaction import LocalPatchMixer - - -def test_local_patch_mixer_basic(): - """Test that output shape is correct and forward pass works.""" - B, N, D = 2, 10, 32 - mixer = LocalPatchMixer(dim=D, kernel_size=3) - - # Grid of coordinates (2x5) - coords = [] - for b in range(B): - bc = [] - for y in range(2): - for x in range(5): - bc.append([x, y]) - coords.append(bc) - coords = torch.tensor(coords, dtype=torch.float32) - - x = torch.randn(B, N, D) - out = mixer(x, coords) - - assert out.shape == x.shape - # Check that residual connection worked (output is not zero) - assert not torch.allclose(out, torch.zeros_like(out)) - - -def test_local_patch_mixer_sparse(): - """Test that gaps in coordinates are handled.""" - dim = 8 - mixer = LocalPatchMixer(dim=dim, kernel_size=3) - - # 3 patches in a line with a gap: (0,0), (2,0), (4,0) - x = torch.zeros(1, 3, dim) - x[0, 0, 0] = 1.0 - x[0, 1, 1] = 1.0 - x[0, 2, 2] = 1.0 - - coords = torch.tensor([[[0, 0], [2, 0], [4, 0]]], dtype=torch.float32) - - with torch.no_grad(): - mixer.conv.weight.fill_(0.0) - mixer.conv.bias.fill_(0.0) - # Set weight to look at neighbor at x+1 (weight[1, 2] in 3x3) - mixer.conv.weight[:, 0, 1, 2] = 1.0 - - out = mixer(x, coords) - - # The patch at 2.0 
looks at 3.0 (empty). Conv result should be 0. - # Out = x + 0 = x. - # Check that ALL values are exactly as original x - torch.testing.assert_close(out, x, atol=1e-4, rtol=1e-4) - - -def test_local_patch_mixer_neighbor_influence(): - """Verify that neighbors actually influence the center patch.""" - dim = 4 - mixer = LocalPatchMixer(dim=dim, kernel_size=3) - - # 3 patches adjacent: (0,0), (1,0), (2,0) - x = torch.zeros(1, 3, dim) - # Patch 1 (center) has signal in channel 0 - x[0, 1, 0] = 1.0 - - coords = torch.tensor([[[0, 0], [1, 0], [2, 0]]], dtype=torch.float32) - - with torch.no_grad(): - mixer.conv.weight.fill_(0.0) - mixer.conv.bias.fill_(0.0) - # Weight[1, 0] looks at neighbor x-1 - # So output at Grid(1,0) looks at Grid(0,0) - # Output at Grid(2,0) looks at Grid(1,0) - mixer.conv.weight[:, 0, 1, 0] = 10.0 - - out = mixer(x, coords) - - # Patch 2 (at 2,0) should receive signal from Patch 1 (at 1,0) - # Patch 1 had 1.0. Conv result at 2,0 should be 1.0 * 10.0 = 10.0. - # After GELU: GELU(10) is approx 10.0. - # Residual: 0.0 + 10.0 = 10.0. - assert out[0, 2, 0] > 9.0 - - # Patch 1 (at 1,0) should receive signal from Patch 0 (at 0,0) which is 0. - # Residual: 1.0 + 0.0 = 1.0. - assert out[0, 1, 0] == 1.0 - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/test_losses.py b/tests/test_losses.py index 8351286..e9ffbeb 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -1,5 +1,6 @@ """ -Tests for loss functions: MaskedMSELoss, PCCLoss, CompositeLoss. +Tests for loss functions: MaskedMSELoss, PCCLoss, CompositeLoss, +AuxiliaryPathwayLoss. Verifies correctness of masking, scale invariance, gradients, and composite objective decomposition. 
@@ -12,6 +13,7 @@ MaskedMSELoss, PCCLoss, CompositeLoss, + AuxiliaryPathwayLoss, ) # --------------------------------------------------------------------------- @@ -184,3 +186,294 @@ def test_different_alphas(self, tensors_2d): loss_high = CompositeLoss(alpha=10.0)(preds, target) # They should differ since PCC != 0 assert loss_low.item() != pytest.approx(loss_high.item(), abs=0.01) + + +# --------------------------------------------------------------------------- +# ZINBLoss +# --------------------------------------------------------------------------- + + +@pytest.fixture +def zinb_tensors(): + """Tensors for ZINB testing: pi, mu, theta, and counts.""" + torch.manual_seed(42) + B, G = 16, 50 + pi = torch.rand(B, G) # [0, 1] + mu = torch.rand(B, G) * 10 + 0.1 # > 0 + theta = torch.rand(B, G) * 5 + 0.1 # > 0 + + # Simulate some zero-inflated counts + counts = torch.poisson(mu) + dropout_mask = torch.rand(B, G) < 0.3 + counts[dropout_mask] = 0.0 + + return (pi, mu, theta), counts + + +class TestZINBLoss: + def test_basic_computation(self, zinb_tensors): + """ZINBLoss should compute a finite scalar loss.""" + from spatial_transcript_former.training.losses import ZINBLoss + + preds, target = zinb_tensors + loss_fn = ZINBLoss() + loss = loss_fn(preds, target) + + assert loss.isfinite() + assert loss.item() > 0 # NLL should be positive + + def test_gradient_flow(self, zinb_tensors): + """Gradients should flow through all three parameters.""" + from spatial_transcript_former.training.losses import ZINBLoss + + pi, mu, theta = zinb_tensors[0] + pi = pi.clone().requires_grad_(True) + mu = mu.clone().requires_grad_(True) + theta = theta.clone().requires_grad_(True) + + target = zinb_tensors[1] + loss = ZINBLoss()((pi, mu, theta), target) + loss.backward() + + assert pi.grad is not None + assert mu.grad is not None + assert theta.grad is not None + assert not torch.isnan(pi.grad).any() + assert not torch.isnan(mu.grad).any() + assert not torch.isnan(theta.grad).any() + + def 
test_mask_support(self): + """ZINBLoss should handle masks in 3D mode correctly.""" + from spatial_transcript_former.training.losses import ZINBLoss + + torch.manual_seed(42) + B, N, G = 2, 10, 5 + pi = torch.rand(B, N, G) + mu = torch.rand(B, N, G) * 5 + 0.1 + theta = torch.rand(B, N, G) * 5 + 0.1 + target = torch.randint(0, 10, (B, N, G)).float() + + mask = torch.zeros(B, N, dtype=torch.bool) + mask[0, 5:] = True # sample 0 half padded + + pi.requires_grad_(True) + + loss_fn = ZINBLoss() + loss = loss_fn((pi, mu, theta), target, mask=mask) + loss.backward() + + assert loss.isfinite() + + # Gradients in padded regions should be zero + padded_grad = pi.grad[0, 5:, :] + assert padded_grad.abs().sum() == 0.0 + + +# --------------------------------------------------------------------------- +# AuxiliaryPathwayLoss +# --------------------------------------------------------------------------- + + +@pytest.fixture +def pathway_tensors(): + """Tensors for AuxiliaryPathwayLoss testing.""" + torch.manual_seed(42) + B, N, G, P = 2, 100, 50, 10 + gene_preds = torch.randn(B, N, G) + target_genes = torch.randn(B, N, G).abs() # Positive counts + pathway_preds = torch.randn(B, N, P) + # Binary MSigDB-like matrix: (P, G) + pathway_matrix = (torch.rand(P, G) > 0.8).float() + mask = torch.zeros(B, N, dtype=torch.bool) + mask[0, 80:] = True + return gene_preds, target_genes, pathway_preds, pathway_matrix, mask + + +class TestAuxiliaryPathwayLoss: + def test_basic_computation(self, pathway_tensors): + """AuxiliaryPathwayLoss should produce a finite scalar.""" + gene_preds, targets, pw_preds, pw_matrix, mask = pathway_tensors + loss_fn = AuxiliaryPathwayLoss(pw_matrix, MaskedMSELoss(), lambda_pathway=0.5) + loss = loss_fn(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + assert loss.isfinite() + + def test_includes_gene_loss(self, pathway_tensors): + """Total loss should be >= gene loss alone.""" + gene_preds, targets, pw_preds, pw_matrix, mask = pathway_tensors + base = 
MaskedMSELoss() + aux = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=0.5) + gene_only = base(gene_preds, targets, mask=mask) + total = aux(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + # Pathway PCC loss is >= 0, so total >= gene_only + assert total.item() >= gene_only.item() - 1e-5 + + def test_gradient_flows_to_both(self, pathway_tensors): + """Gradients should flow to both gene_preds and pathway_preds.""" + gene_preds, targets, pw_preds, pw_matrix, mask = pathway_tensors + gene_preds = gene_preds.clone().requires_grad_(True) + pw_preds = pw_preds.clone().requires_grad_(True) + loss_fn = AuxiliaryPathwayLoss(pw_matrix, MaskedMSELoss(), lambda_pathway=0.5) + loss = loss_fn(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + loss.backward() + assert gene_preds.grad is not None + assert pw_preds.grad is not None + assert gene_preds.grad.abs().sum() > 0 + assert pw_preds.grad.abs().sum() > 0 + + def test_fallback_without_pathways(self, pathway_tensors): + """When pathway_preds is None, should fall back to gene loss only.""" + gene_preds, targets, _, pw_matrix, mask = pathway_tensors + base = MaskedMSELoss() + aux = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=0.5) + gene_only = base(gene_preds, targets, mask=mask) + fallback = aux(gene_preds, targets, mask=mask, pathway_preds=None) + assert torch.allclose(gene_only, fallback) + + def test_lambda_zero_disables(self, pathway_tensors): + """lambda_pathway=0 should produce the same result as gene loss.""" + gene_preds, targets, pw_preds, pw_matrix, mask = pathway_tensors + base = MaskedMSELoss() + aux = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=0.0) + gene_only = base(gene_preds, targets, mask=mask) + total = aux(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + assert torch.allclose(gene_only, total) + + def test_zinb_integration(self): + """AuxiliaryPathwayLoss should work with ZINB (pi, mu, theta) output.""" + from spatial_transcript_former.training.losses import ( 
+ ZINBLoss, + AuxiliaryPathwayLoss, + MaskedMSELoss, + ) + + torch.manual_seed(42) + B, N, G, P = 2, 10, 50, 10 + pi = torch.rand(B, N, G) + mu = torch.rand(B, N, G) * 5 + 1.0 + theta = torch.rand(B, N, G) * 5 + 1.0 + zinb_preds = (pi, mu, theta) + + targets = torch.randint(0, 10, (B, N, G)).float() + pw_preds = torch.randn(B, N, P) + pw_matrix = (torch.rand(P, G) > 0.8).float() + mask = torch.zeros(B, N, dtype=torch.bool) + mask[0, 5:] = True + + loss_fn = AuxiliaryPathwayLoss(pw_matrix, ZINBLoss(), lambda_pathway=1.0) + loss = loss_fn(zinb_preds, targets, mask=mask, pathway_preds=pw_preds) + + assert loss.isfinite() + assert loss.item() > 0 + + def test_perfect_match_zero_aux(self, pathway_tensors): + """When pathway preds perfectly correlate with truth, aux loss contribution is 0.""" + gene_preds, targets, _, pw_matrix, mask = pathway_tensors + base = MaskedMSELoss() + aux = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=1.0) + + # Compute ground truth pathways + with torch.no_grad(): + target_pathways = torch.matmul(targets, pw_matrix.T) + + gene_loss = base(gene_preds, targets, mask=mask) + # Use target_pathways as pathway_preds + total_loss = aux(gene_preds, targets, mask=mask, pathway_preds=target_pathways) + + # Total loss should roughly equal gene loss (since 1-PCC should be 0) + assert total_loss.item() == pytest.approx(gene_loss.item(), abs=1e-5) + + def test_numerical_stability_empty_spots(self, pathway_tensors): + """Test stability when targets for some pathways are all zero.""" + gene_preds, targets, pw_preds, pw_matrix, mask = pathway_tensors + + # Zero out targets for the first pathway + # Matrix is (P, G), so genes involved in pathway 0 + genes_in_p0 = pw_matrix[0].bool() + targets[..., genes_in_p0] = 0.0 + + loss_fn = AuxiliaryPathwayLoss(pw_matrix, MaskedMSELoss(), lambda_pathway=1.0) + loss = loss_fn(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + + assert loss.isfinite() + assert not torch.isnan(loss) + + def 
test_lambda_scaling(self, pathway_tensors): + """Doubling lambda_pathway should increase the aux contribution accordingly.""" + gene_preds, targets, pw_preds, pw_matrix, mask = pathway_tensors + base = MaskedMSELoss() + + aux1 = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=0.5) + aux2 = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=1.0) + + loss1 = aux1(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + loss2 = aux2(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + gene_loss = base(gene_preds, targets, mask=mask) + + term1 = loss1 - gene_loss + term2 = loss2 - gene_loss + + # term2 should be approx 2 * term1 + assert term2.item() == pytest.approx(2 * term1.item(), rel=1e-4) + + def test_hallmark_integration(self): + """Test with a real (though small) MSigDB Hallmark matrix.""" + from spatial_transcript_former.data.pathways import get_pathway_init + from spatial_transcript_former.training.losses import ( + AuxiliaryPathwayLoss, + MaskedMSELoss, + ) + + # Subset of genes known to be in MSigDB Hallmarks + gene_list = ["TP53", "MYC", "VEGFA", "VIM", "CDH1", "SNAI1", "AXIN2", "MLH1"] + B, N, G = 2, 5, len(gene_list) + + # Get real Hallmark matrix (only for our 8 genes) + matrix, names = get_pathway_init(gene_list, verbose=False) + P = len(names) + + preds = torch.randn(B, N, G) + targets = torch.randn(B, N, G).abs() + pw_preds = torch.randn(B, N, P) + + loss_fn = AuxiliaryPathwayLoss(matrix, MaskedMSELoss()) + loss = loss_fn(preds, targets, pathway_preds=pw_preds) + + assert loss.isfinite() + assert len(names) > 0 + + def test_hallmark_signal_detection(self): + """Verify that gene patterns aligned with a hallmark reduce the aux loss.""" + from spatial_transcript_former.training.losses import ( + AuxiliaryPathwayLoss, + MaskedMSELoss, + ) + + # Define 4 genes and 2 pathways + # P0: G0, G1 + # P1: G2, G3 + pw_matrix = torch.tensor( + [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]], dtype=torch.float32 + ) + gene_list = ["G0", "G1", "G2", "G3"] + 
B, N, G = 1, 10, 4 + P = 2 + + # Case 1: Random targets + torch.manual_seed(42) + targets = torch.randn(B, N, G).abs() + # Predictions match targets for genes + gene_preds = targets.clone() + + # Pathway preds are random + pw_preds_random = torch.randn(B, N, P) + + loss_fn = AuxiliaryPathwayLoss(pw_matrix, MaskedMSELoss(), lambda_pathway=1.0) + loss_random = loss_fn(gene_preds, targets, pathway_preds=pw_preds_random) + + # Case 2: Pathway preds perfectly match truth (which is targets @ matrix.T) + pw_truth = torch.matmul(targets, pw_matrix.T) + loss_perfect = loss_fn(gene_preds, targets, pathway_preds=pw_truth) + + # Case 3: Gene expression is specifically high for P0, and pw_preds are high for P0 + # If the spatial correlation is high, aux loss should be low. + assert loss_perfect.item() < loss_random.item() diff --git a/tests/test_model_backbones.py b/tests/test_model_backbones.py index cb4fc79..090dbb3 100644 --- a/tests/test_model_backbones.py +++ b/tests/test_model_backbones.py @@ -38,9 +38,10 @@ def test_interaction_model_backbone(): num_genes=num_genes, backbone_name="resnet50", pretrained=False ) - # Test with raw image input + # Test with raw image input (single patch => S=1) x = torch.randn(4, 3, 224, 224) - out = model(x) + rel_coords = torch.randn(4, 1, 2) + out = model(x, rel_coords=rel_coords) assert out.shape == (4, num_genes) diff --git a/tests/test_models.py b/tests/test_models.py index e16f055..39b2b1c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,54 +1,30 @@ import pytest import torch from spatial_transcript_former.models import SpatialTranscriptFormer -from spatial_transcript_former.models.interaction import MultimodalFusion def test_interaction_output_shape(mock_image_batch): """ EDUCATIONAL: This test verifies that the SpatialTranscriptFormer correctly maps a batch of histology images to a high-dimensional gene expression vector. 
- - Architecture: - Input (x) -> Backbone -> Interaction (Cross-Attention) -> Gene Reconstructor -> Output (y) """ num_genes = 1000 model = SpatialTranscriptFormer(num_genes=num_genes) - output = model(mock_image_batch) - - # Verify shape: (Batch Size, Number of Genes) - assert output.shape == (mock_image_batch.shape[0], num_genes) - - -def test_multimodal_fusion_quadrant_masking(): - """ - EDUCATIONAL: This test verifies Jaume Et Al. Equation 1 quadrant masking. - We verify that masking a quadrant (P2H: Pathway to Histology) effectively - blocks information flow, resulting in zero gradients for histology tokens. - - Matrix A = [ P2P P2H ] - [ H2P H2H ] - """ - dim, Np, Nh, B = 64, 5, 10, 2 - - # Mode: Mask P2H (Pathways cannot see Histology) - fusion = MultimodalFusion(dim, n_heads=4, n_layers=1, masked_quadrants=["P2H"]) - - p_tokens = torch.randn(B, Np, dim, requires_grad=True) - h_tokens = torch.randn(B, Nh, dim, requires_grad=True) + # Must provide rel_coords since use_spatial_pe defaults to True + B = mock_image_batch.shape[0] + if mock_image_batch.dim() == 5: + S = mock_image_batch.shape[1] + elif mock_image_batch.dim() == 4: + S = 1 + else: + S = mock_image_batch.shape[1] - # Forward pass: Results in Pathway tokens (Bottleneck) - out = fusion(p_tokens, h_tokens) + rel_coords = torch.randn(B, S, 2) + output = model(mock_image_batch, rel_coords=rel_coords) - # Backpropagate signal from output - out.sum().backward() - - # ASSERTION: If P2H is masked, histology tokens (H) should receive ZERO signal from the output - assert h_tokens.grad is not None - assert ( - h_tokens.grad.abs().sum() < 1e-6 - ), "Histology tokens should have no gradient if P2H quadrant is masked" + # Verify shape: (Batch Size, Number of Genes) + assert output.shape == (B, num_genes) def test_sparsity_regularization_loss(): @@ -66,21 +42,3 @@ def test_sparsity_regularization_loss(): # Expect a positive scalar (L1 norm of reconstruction weights) assert sparsity_loss > 0 assert 
sparsity_loss.dim() == 0 - - -def test_interaction_with_masking(mock_image_batch): - """ - EDUCATIONAL: Verifies that the unified interaction model works with - different quadrant masking configurations. - """ - num_genes = 100 - - # Test with H2H masking (Default) - model = SpatialTranscriptFormer(num_genes=num_genes, masked_quadrants=["H2H"]) - out = model(mock_image_batch) - assert out.shape == (mock_image_batch.shape[0], num_genes) - - # Test with all quadrants open - model_all = SpatialTranscriptFormer(num_genes=num_genes, masked_quadrants=[]) - out_all = model_all(mock_image_batch) - assert out_all.shape == (mock_image_batch.shape[0], num_genes) diff --git a/tests/test_neighborhood.py b/tests/test_neighborhood.py deleted file mode 100644 index eb363ce..0000000 --- a/tests/test_neighborhood.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -import numpy as np -import sys -import os - -# Add src to path -sys.path.append(os.path.abspath("src")) - -from spatial_transcript_former.models import SpatialTranscriptFormer -from spatial_transcript_former.data.dataset import HEST_Dataset - - -def test_neighborhood_model_forward(): - B = 2 - S = 9 # 1 center + 8 neighbors - C, H, W = 3, 224, 224 - G = 1000 - - model = SpatialTranscriptFormer(num_genes=G) - - # Input sequence - x = torch.randn(B, S, C, H, W) - rel_coords = torch.randn(B, S, 2) - - # Forward pass - out = model(x, rel_coords=rel_coords) - - print(f"Neighborhood Input shape: {x.shape}") - print(f"Output shape: {out.shape}") - - assert out.shape == (B, G), f"Expected (B, G), got {out.shape}" - print("Neighborhood forward pass test passed!") - - -def test_spatial_masking(): - B = 1 - S = 4 - G = 100 - # Center at (0,0), neighbors at (10,0), (100,0), (1000,0) - rel_coords = torch.tensor( - [[[0, 0], [10, 0], [100, 0], [1000, 0]]], dtype=torch.float32 - ) - - # Set mask radius to 50 -> neighbors 2 and 3 should be masked - model = SpatialTranscriptFormer(num_genes=G, mask_radius=50) - - mask = 
model._generate_spatial_mask(rel_coords) - # Expected mask: [False, False, True, True] (True means ignore) - - print(f"Rel Coords: {rel_coords}") - print(f"Generated Mask: {mask}") - - assert mask[0, 0] == False - assert mask[0, 1] == False - assert mask[0, 2] == True - assert mask[0, 3] == True - - print("Spatial masking logic test passed!") - - -if __name__ == "__main__": - test_neighborhood_model_forward() - test_spatial_masking() diff --git a/tests/test_nystrom.py b/tests/test_nystrom.py deleted file mode 100644 index 6b9c92e..0000000 --- a/tests/test_nystrom.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -import pytest -from spatial_transcript_former.models import SpatialTranscriptFormer - - -def test_nystrom_jaume_mode(mock_image_batch): - """ - EDUCATIONAL: Verifies that Nystrom Attention works in 'jaume' (early fusion) mode. - Nystrom provides linear complexity, allowing for multimodal interaction - without the quadratic memory wall of standard self-attention. - """ - num_genes = 100 - model = SpatialTranscriptFormer( - num_genes=num_genes, use_nystrom=True, num_landmarks=32 # Small for testing - ) - - output = model(mock_image_batch) - assert output.shape == (mock_image_batch.shape[0], num_genes) - - -def test_nystrom_no_quadrant_mask_support(mock_image_batch): - """ - EDUCATIONAL: Nystrom Attention (linear complexity) does not support - standard 2D quadrant masking. This test verifies the model still - executes without crashing, but we acknowledge the mask is ignored. - """ - num_genes = 100 - model = SpatialTranscriptFormer( - num_genes=num_genes, - use_nystrom=True, - masked_quadrants=["H2H"], - num_landmarks=32, - ) - - output = model(mock_image_batch) - assert output.shape == (mock_image_batch.shape[0], num_genes) - - -def test_nystrom_scalability(): - """ - EDUCATIONAL: This test verifies that Nystrom mode can handle very large - histology sequences that would typically crash a standard transformer. 
- """ - B, S, D = 1, 1024, 512 # Large neighborhood - num_genes = 10 - - # We'll use mock projection features instead of raw images to save memory in CI - # token_dim=512 is used in SpatialTranscriptFormer - features = torch.randn(B, S, 2048) # Backbone output dim - - model = SpatialTranscriptFormer( - num_genes=num_genes, use_nystrom=True, num_landmarks=128 - ) - - # Forward pass on a sequence of 1024 tokens - with torch.no_grad(): - output = model(features) - - assert output.shape == (B, num_genes) diff --git a/tests/test_quadrant_masking.py b/tests/test_quadrant_masking.py deleted file mode 100644 index e43bf1d..0000000 --- a/tests/test_quadrant_masking.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.nn as nn -import sys -import os - -# Add src to path -sys.path.append(os.path.abspath("src")) - -from spatial_transcript_former.models.interaction import MultimodalFusion - - -def test_quadrant_masking_interactions(): - dim = 64 - n_heads = 4 - n_layers = 1 - Np = 5 - Nh = 10 - B = 2 - - # Mode 1: Full Attention - fusion_full = MultimodalFusion(dim, n_heads, n_layers) - - # Mode 2: Mask H2H - fusion_masked = MultimodalFusion(dim, n_heads, n_layers, masked_quadrants=["H2H"]) - - p_tokens = torch.randn(B, Np, dim, requires_grad=True) - h_tokens = torch.randn(B, Nh, dim, requires_grad=True) - - # Forward full - out_full = fusion_full(p_tokens, h_tokens) - loss_full = out_full.sum() - loss_full.backward() - - h_grad_full = h_tokens.grad.abs().sum() - print(f"H tokens grad (Full): {h_grad_full.item()}") - - # Reset grad - h_tokens.grad.zero_() - p_tokens.grad.zero_() - - # Forward masked H2H - # NOTE: Masking H2H in self-attention technically means H tokens don't attend to H tokens. - # But H tokens can still attend to P tokens (H2P) unless we mask that too. - # If we mask H2H, H tokens are purely transformed by their interaction with P tokens. 
- out_masked = fusion_masked(p_tokens, h_tokens) - loss_masked = out_masked.sum() - loss_masked.backward() - - h_grad_masked = h_tokens.grad.abs().sum() - print(f"H tokens grad (Masked H2H): {h_grad_masked.item()}") - - # Forward masked H2H AND H2P (H tokens isolated) - h_tokens.grad.zero_() - fusion_isolated = MultimodalFusion( - dim, n_heads, n_layers, masked_quadrants=["H2H", "H2P"] - ) - out_iso = fusion_isolated(p_tokens, h_tokens) - out_iso.sum().backward() - - h_grad_iso = h_tokens.grad.abs().sum() - print(f"H tokens grad (H2H + H2P Masked): {h_grad_iso.item()}") - - # In H2H + H2P masked mode, output depends on P tokens attending to H and P. - # H tokens shouldn't receive any gradients from the output if they are isolated - # and the output only contains the transformed P tokens. - # WAIT: P tokens STILL attend to H tokens (P2H). So H tokens WILL get gradients. - - # Let's verify P2H masking - h_tokens.grad.zero_() - fusion_no_hist = MultimodalFusion(dim, n_heads, n_layers, masked_quadrants=["P2H"]) - out_no_hist = fusion_no_hist(p_tokens, h_tokens) - out_no_hist.sum().backward() - h_grad_no_hist = h_tokens.grad.abs().sum() - print(f"H tokens grad (P2H Masked): {h_grad_no_hist.item()}") - - assert ( - h_grad_no_hist < 1e-5 - ), "H tokens should not receive gradients if P2H is masked" - print("Quadrant masking verification passed!") - - -if __name__ == "__main__": - test_quadrant_masking_interactions() diff --git a/tests/test_spatial.py b/tests/test_spatial.py deleted file mode 100644 index 5a5e094..0000000 --- a/tests/test_spatial.py +++ /dev/null @@ -1,55 +0,0 @@ -import pytest -import torch -from spatial_transcript_former.models import SpatialTranscriptFormer - - -def test_neighborhood_forward_pass(mock_neighborhood_batch, mock_rel_coords): - """ - EDUCATIONAL: This test verifies that the model can process a sequence of - histology patches (Neighborhood Mode). 
- - Input shape: (Batch, Sequence Length, Channels, Height, Width) - Output shape: (Batch, Number of Genes) - """ - num_genes = 1000 - model = SpatialTranscriptFormer(num_genes=num_genes) - - # Forward pass with neighborhood sequence and relative coordinates - output = model(mock_neighborhood_batch, rel_coords=mock_rel_coords) - - assert output.shape == (mock_neighborhood_batch.shape[0], num_genes) - - -def test_distance_based_spatial_masking(): - """ - EDUCATIONAL: This test verifies that the 'mask_radius' correctly suppresses - interactions with distant patches in the neighborhood. - - We place patches at different distances and verify that the generated - attention mask ignores those beyond the radius. - """ - G = 100 - # Relative Coords: Center(0,0), Near(10,0), Far(100,0), Very Far(1000,0) - rel_coords = torch.tensor( - [ - [ - [0, 0], # Index 0 (Center) - [10, 0], # Index 1 (Near) - [100, 0], # Index 2 (Far) - [1000, 0], # Index 3 (Very Far) - ] - ], - dtype=torch.float32, - ) - - # Set radius to 50: Only Center and Near should be visible - model = SpatialTranscriptFormer(num_genes=G, mask_radius=50) - - # Internal method creates a boolean mask (True = Ignore) - mask = model._generate_spatial_mask(rel_coords) - - # Expected: [False (Seen), False (Seen), True (Masked), True (Masked)] - assert mask[0, 0] == False - assert mask[0, 1] == False - assert mask[0, 2] == True - assert mask[0, 3] == True diff --git a/tests/test_spatial_alignment.py b/tests/test_spatial_alignment.py index f5ac0b0..611ec2e 100644 --- a/tests/test_spatial_alignment.py +++ b/tests/test_spatial_alignment.py @@ -5,28 +5,14 @@ def test_spatial_mixing_with_large_coordinates(): """ - Verifies that the model correctly normalizes large pixel coordinates (e.g. 256px steps) - so that local mixing can occur between adjacent patches. + Verifies that the model correctly handles large pixel coordinates (e.g. 256px steps) + and that gradient flows between patches via the pathway bottleneck. 
""" - # ResNet50 output dim is 2048, so we must use token_dim=2048 or project it. - # The model has an image_proj layer: Linear(image_feature_dim, token_dim). - # Wait, the error was 2x64 and 2048x64? - # Ah, the error "mat1 and mat2 shapes cannot be multiplied (2x64 and 2048x64)" - # suggests the input x has dim 64, but the first layer (backbone) expects something else? - # Actually, the model forward() expects images (B, C, H, W) OR features (B, N, D). - # If passing features, D must match image_feature_dim (2048 for ResNet). - - # Let's verify model source: - # if x.dim() == 3: features = x - # memory = self.image_proj(features) - # image_proj is Linear(2048, token_dim). - - # So input x mus be (B, N, 2048). dim = 2048 token_dim = 64 model = SpatialTranscriptFormer( - num_genes=10, token_dim=token_dim, n_layers=1, use_spatial_pe=True + num_genes=10, token_dim=token_dim, n_layers=2, use_spatial_pe=True ) # Create two patches that are physically adjacent (256px apart) but logically neighbors @@ -40,8 +26,8 @@ def test_spatial_mixing_with_large_coordinates(): output = model(x, rel_coords=coords) # Check gradient flow from Patch 1 to Patch 0's output - # If mixing works, output[0] should be influenced by input[1] via the LocalPatchMixer - # We sum output[0] and check grad w.r.t x[0, 1] + # With the pathway bottleneck, both patches attend to pathway tokens, so + # indirect gradient flow exists through the shared pathway embeddings. loss = output[0, 0].sum() loss.backward() @@ -49,8 +35,7 @@ def test_spatial_mixing_with_large_coordinates(): assert grad_neighbor > 0.0, ( f"Gradient from neighbor is zero ({grad_neighbor}). " - "The model failed to mix features between patches separated by 256px. " - "This indicates the LocalPatchMixer is treating them as distant neighbors." + "The model failed to propagate gradients between patches via pathway tokens." 
) diff --git a/tests/test_spatial_interaction.py b/tests/test_spatial_interaction.py index 308e861..b887043 100644 --- a/tests/test_spatial_interaction.py +++ b/tests/test_spatial_interaction.py @@ -3,58 +3,163 @@ from unittest.mock import MagicMock from spatial_transcript_former.models.interaction import ( SpatialTranscriptFormer, - LocalPatchMixer, - GraphPatchMixer, + LearnedSpatialEncoder, + VALID_INTERACTIONS, ) from spatial_transcript_former.training.engine import train_one_epoch, validate -def test_normalize_coords_finds_correct_step(): - """ - Test that _normalize_coords correctly infers the grid spacing (e.g., 224) - even when absolute coordinates are very large, which breaks the current median - absolute value heuristic. - """ - model = SpatialTranscriptFormer(num_genes=10, use_nystrom=False) +def test_n_layers_enforcement(): + """n_layers < 2 with h2h blocked should raise ValueError.""" + with pytest.raises(ValueError, match="n_layers must be >= 2"): + SpatialTranscriptFormer( + num_genes=50, n_layers=1, interactions=["p2p", "p2h", "h2p"] + ) + - # Simulate a 2x2 grid of patches with patch size 224 in absolute pixel coords - coords = torch.tensor( - [[10000.0, 20000.0], [10224.0, 20000.0], [10000.0, 20224.0], [10224.0, 20224.0]] - ).unsqueeze( - 0 - ) # Output shape: (1, 4, 2) +def test_n_layers_ok_with_full_interactions(): + """n_layers=1 is allowed when h2h is enabled (full attention).""" + model = SpatialTranscriptFormer( + num_genes=50, + token_dim=64, + n_heads=4, + n_layers=1, + interactions=["p2p", "p2h", "h2p", "h2h"], + ) + assert model is not None - normalized = model._normalize_coords(coords) - # LocalPatchMixer subtracts the minimum coordinate to create a zero-indexed grid. 
- min_c = normalized.min(dim=1, keepdim=True)[0] - grid_coords = normalized - min_c +def test_invalid_interaction_key(): + """Unknown interaction keys should raise ValueError.""" + with pytest.raises(ValueError, match="Unknown interaction keys"): + SpatialTranscriptFormer(num_genes=50, interactions=["p2p", "x2y"]) - expected_grid = torch.tensor( - [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]] - ).unsqueeze(0) - # The current heuristic fails this because it divides by ~15000 (median) - assert torch.allclose( - grid_coords, expected_grid - ), f"Coordinate normalization failed, returned:\n{grid_coords}" +@pytest.mark.parametrize( + "interactions", + [ + ["p2p", "p2h", "h2p", "h2h"], # full + ["p2p", "p2h", "h2p"], # bottleneck + ["p2p", "p2h"], # pathway-only + ], +) +def test_interaction_combinations(interactions): + """Various interaction combos should produce correct output shapes.""" + model = SpatialTranscriptFormer( + num_genes=50, + token_dim=64, + n_heads=4, + n_layers=2, + interactions=interactions, + ) + B, N, D = 2, 5, 2048 + features = torch.randn(B, N, D) + coords = torch.randn(B, N, 2) + out = model(features, rel_coords=coords) + assert out.shape == (B, 50) + + +def test_full_interactions_returns_no_mask(): + """When all interactions are enabled, _build_interaction_mask returns None.""" + model = SpatialTranscriptFormer( + num_genes=50, + token_dim=64, + n_heads=4, + n_layers=2, + interactions=["p2p", "p2h", "h2p", "h2h"], + ) + mask = model._build_interaction_mask(p=10, s=20, device=torch.device("cpu")) + assert mask is None + +def test_missing_coords_raises(): + """use_spatial_pe=True without coords should raise ValueError.""" + model = SpatialTranscriptFormer(num_genes=50, token_dim=64, n_heads=4, n_layers=2) + features = torch.randn(2, 10, 2048) + with pytest.raises(ValueError, match="rel_coords was not provided"): + model(features) -def test_engine_passes_coords_to_dense_forward(): + +def test_spatial_transcript_former_dense_forward(): + 
"""Instantiate the model and ensure forward() with return_dense=True executes properly.""" + model = SpatialTranscriptFormer( + num_genes=50, + token_dim=64, + n_heads=4, + n_layers=2, + ) + + B, N, D = 2, 10, 2048 + features = torch.randn(B, N, D) + coords = torch.randn(B, N, 2) + + # 1. Standard Forward Pass + out = model(features, rel_coords=coords) + assert out.shape == ( + B, + 50, + ), f"Expected global output shape {(B, 50)}, got {out.shape}" + + # Reset grads + model.zero_grad() + + # 2. Dense Forward Pass + out_dense = model.forward(features, rel_coords=coords, return_dense=True) + assert out_dense.shape == ( + B, + N, + 50, + ), f"Expected dense output shape {(B, N, 50)}, got {out_dense.shape}" + + out_dense.sum().backward() + assert ( + model.fusion_engine.layers[0].self_attn.in_proj_weight.grad is not None + ), "Gradients did not flow through the transformer in dense pass." + + +def test_unified_pathway_scoring(): + """Both global and dense modes should produce pathway_scores from dot-products.""" + model = SpatialTranscriptFormer( + num_genes=50, + token_dim=64, + n_heads=4, + n_layers=2, + num_pathways=10, + ) + + B, N, D = 2, 5, 2048 + features = torch.randn(B, N, D) + coords = torch.randn(B, N, 2) + + # Global mode returns (gene_expression, pathway_scores) + gene_expr, pw_scores = model(features, rel_coords=coords, return_pathways=True) + assert gene_expr.shape == (B, 50) + assert pw_scores.shape == (B, 10) # (B, P) from pooled dot-product + + # Dense mode returns (gene_expression, pathway_scores) + gene_expr_d, pw_scores_d = model( + features, rel_coords=coords, return_pathways=True, return_dense=True + ) + assert gene_expr_d.shape == (B, N, 50) + assert pw_scores_d.shape == (B, N, 10) # (B, S, P) from per-patch dot-product + + +def test_engine_passes_coords_to_forward(): """ - Verify that the training engine passes the generated spatial coordinates to - the model in whole_slide mode. 
+ Verify that the training engine passes spatial coordinates to + model.forward() in whole_slide mode. """ model = SpatialTranscriptFormer(num_genes=10) - # Mock to track if coords are passed - model.forward_dense = MagicMock(return_value=torch.randn(2, 5, 10)) + # Mock forward to track calls + original_forward = model.forward + model.forward = MagicMock(return_value=torch.randn(2, 5, 10)) model.get_sparsity_loss = MagicMock(return_value=torch.tensor(0.0)) fake_coords = torch.randn(2, 5, 2) fake_mask = torch.zeros(2, 5).bool() - # Datloader yielding (feats, genes, coords, mask) + # Dataloader yielding (feats, genes, coords, mask) loader = [(torch.randn(2, 5, 512), torch.randn(2, 5, 10), fake_coords, fake_mask)] optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) @@ -64,12 +169,12 @@ def dummy_criterion(p, t, mask=None): train_one_epoch(model, loader, dummy_criterion, optimizer, "cpu", whole_slide=True) - model.forward_dense.assert_called_once() - kwargs = model.forward_dense.call_args.kwargs + model.forward.assert_called_once() + kwargs = model.forward.call_args.kwargs - assert "coords" in kwargs, "Engine did not pass 'coords' kwargs to forward_dense!" + assert "rel_coords" in kwargs, "Engine did not pass 'rel_coords' kwargs to forward!" assert torch.allclose( - kwargs["coords"], fake_coords + kwargs["rel_coords"], fake_coords ), "Engine passed wrong coordinate tensor!" @@ -78,7 +183,7 @@ def test_engine_validate_passes_coords(): Verify validation loop passes coords. 
""" model = SpatialTranscriptFormer(num_genes=10) - model.forward_dense = MagicMock(return_value=torch.randn(2, 5, 10)) + model.forward = MagicMock(return_value=torch.randn(2, 5, 10)) fake_coords = torch.randn(2, 5, 2) loader = [ @@ -95,79 +200,12 @@ def dummy_criterion(p, t, mask=None): validate(model, loader, dummy_criterion, "cpu", whole_slide=True) - model.forward_dense.assert_called_once() - kwargs = model.forward_dense.call_args.kwargs + model.forward.assert_called_once() + kwargs = model.forward.call_args.kwargs assert ( - "coords" in kwargs - ), "Validate engine did not pass 'coords' kwargs to forward_dense!" + "rel_coords" in kwargs + ), "Validate engine did not pass 'rel_coords' kwargs to forward!" assert torch.allclose( - kwargs["coords"], fake_coords + kwargs["rel_coords"], fake_coords ), "Validate engine passed wrong coordinate tensor!" - - -def test_graph_patch_mixer(): - """Verify that the GraphPatchMixer correctly performs message passing over a k-NN graph.""" - B, N, D = 2, 10, 32 - k = 3 - mixer = GraphPatchMixer(dim=D, k=k, heads=4) - - x = torch.randn(B, N, D) - coords = torch.randn(B, N, 2) - - # 1. Forward Pass - out = mixer(x, coords) - - # 2. Shape Verification - assert out.shape == (B, N, D), f"Expected shape {(B, N, D)}, got {out.shape}" - - # 3. Gradient Flow Verification - out.sum().backward() - assert ( - mixer.to_qkv.weight.grad is not None - ), "Gradients did not flow through the GAT layer." 
- - -def test_spatial_transcript_former_with_gnn_refiner(): - """Instantiate the model with GNN refiner and ensure forward() and forward_dense() execute properly.""" - model = SpatialTranscriptFormer( - num_genes=50, - token_dim=64, - n_heads=4, - n_layers=1, - early_mixer=None, - late_refiner="gnn", - ) - - B, N, D = 2, 10, 2048 - features = torch.randn(B, N, D) - coords = torch.randn(B, N, 2) - - # Check that early_mixer is None and late_refiner is initialized - assert model.early_mixer is None, "early_mixer should be None" - assert ( - hasattr(model, "late_refiner") and model.late_refiner is not None - ), "late_refiner not initialized" - - # 1. Standard Forward Pass - out = model(features, rel_coords=coords) - assert out.shape == ( - B, - 50, - ), f"Expected dense output shape {(B, 50)}, got {out.shape}" - - # Reset grads - model.zero_grad() - - # 2. Dense Forward Pass (This is the one that uses the Graph Refiner explicitly) - out_dense = model.forward_dense(features, coords=coords) - assert out_dense.shape == ( - B, - N, - 50, - ), f"Expected dense output shape {(B, N, 50)}, got {out_dense.shape}" - - out_dense.sum().backward() - assert ( - model.late_refiner.to_qkv.weight.grad is not None - ), "Gradients did not flow through the late_refiner in dense pass." diff --git a/tests/test_warnings.py b/tests/test_warnings.py new file mode 100644 index 0000000..f94ead4 --- /dev/null +++ b/tests/test_warnings.py @@ -0,0 +1,44 @@ +import pytest +import warnings + + +def function_that_warns(): + warnings.warn("This is a deprecated feature", DeprecationWarning) + return True + + +def function_that_warns_user(): + warnings.warn("FigureCanvasAgg is non-interactive", UserWarning) + return True + + +def test_demonstrate_warning_assertion(): + """ + Demonstrates how to assert that a specific warning is raised. + This is useful for ensuring your code warns users correctly. 
+ """ + with pytest.warns(DeprecationWarning, match="deprecated feature"): + result = function_that_warns() + assert result is True + + +def test_global_filter_demonstration(): + """ + This test will pass without showing warnings in the output + because we added filters to pyproject.toml. + + Specifically 'FigureCanvasAgg is non-interactive' is filtered. + """ + result = function_that_warns_user() + assert result is True + + +def test_how_to_catch_and_ignore_locally(): + """ + If you want to ignore a warning locally in a specific test + without adding it to the global pyproject.toml. + """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + result = function_that_warns() + assert result is True