From fb7a4cfa327c10e1dcc80cfb2ea7fbc0bfaa7f3f Mon Sep 17 00:00:00 2001 From: BenjaminIsaac0111 <12176376+BenjaminIsaac0111@users.noreply.github.com> Date: Tue, 24 Feb 2026 17:02:34 +0000 Subject: [PATCH 1/4] feat: implement quad-flow interaction system and major architectural cleanup Implemented a flexible attention masking system (Quad-Flow) that allows granular configuration of interaction quadrants (p2p, p2h, h2p, h2h), now defaulting to full interaction mode. This refactor simplifies the core architecture while introducing robust biological supervision and scalable presets. Key Changes: - Model Architecture: * Replaced binary masking with a composable '--interactions' list. * Integrated AuxiliaryPathwayLoss for direct biological supervision and ZINB loss support for raw count modeling. * Simplified core logic by inlining 'PathwayTokenizer' and removing legacy 5D/2D branches, dead aliases, and deprecated forward paths. * Added 'return_attention' support for cross-layer attention map extraction. - Diagnostics & Scalability: * Added 'diagnose_collapse.py' to monitor pathway diversity and analyze mean attention across the four interaction quadrants. * Refactored 'run_preset.py' into a unified flag system with new scaled model variants (L2, L4, L6). * Standardized environment setup (PS1/SH) on Python 3.9 for CI consistency. - Documentation & Project: * Rewrote 'MODELS.md' to detail the Quad-Flow Interaction design and auxiliary loss branch. * Updated 'README.md' and 'TRAINING_GUIDE.md' with ZINB recipes, training presets, and feature highlights. * Introduced 'CONTRIBUTING.md' to define IP rules and dev standards. - Verification: * Expanded test suite to verify interaction logic, ZINB/Auxiliary loss stability, and MSigDB signal detection. --- CONTRIBUTING.md | 60 ++ README.md | 69 +- config.yaml | 5 +- docs/MODELS.md | 339 ++++--- docs/TRAINING_GUIDE.md | 41 +- scripts/diagnose_collapse.py | 414 +++++++++ scripts/monitor.py | 139 +-- scripts/run_preset.py | 150 ++- setup.ps1 | 5 +- setup.sh | 6 +- .../models/interaction.py | 874 +++++------------- src/spatial_transcript_former/predict.py | 41 +- src/spatial_transcript_former/train.py | 114 ++- .../training/engine.py | 104 ++- .../training/losses.py | 218 ++++- .../visualization.py | 120 ++- tests/test_bottleneck_arch.py | 7 +- tests/test_checkpoint.py | 95 +- tests/test_interactions.py | 176 ++++ tests/test_local_patch_mixer.py | 92 -- tests/test_losses.py | 295 +++++- tests/test_model_backbones.py | 5 +- tests/test_models.py | 66 +- tests/test_neighborhood.py | 27 - tests/test_nystrom.py | 59 -- tests/test_quadrant_masking.py | 83 -- tests/test_spatial.py | 35 - tests/test_spatial_alignment.py | 27 +- tests/test_spatial_interaction.py | 256 ++--- 29 files changed, 2273 insertions(+), 1649 deletions(-) create mode 100644 CONTRIBUTING.md create mode 100644 scripts/diagnose_collapse.py create mode 100644 tests/test_interactions.py delete mode 100644 tests/test_local_patch_mixer.py delete mode 100644 tests/test_nystrom.py delete mode 100644 tests/test_quadrant_masking.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..bd90cd8 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,60 @@ +# Contributing to SpatialTranscriptFormer + +Thank you for your interest in contributing! As a project at the intersection of deep learning and pathology, we value rigorous, well-tested contributions. + +## Project Status + +> [!IMPORTANT] +> This project is a **Work in Progress**. We are actively refining the core interaction logic and scaling behaviors. Expect breaking changes in the CLI and data schemas. + +## Intellectual Property & Licensing + +SpatialTranscriptFormer is protected under a **Proprietary Source Code License**. + +- **Academic/Non-Profit**: We encourage contributions from the research community. Contributions made under an academic affiliation are generally welcome. +- **Commercial/For-Profit**: Contributions from commercial entities or individuals intended for profit-seeking use require a separate agreement. +- **Assignment**: By submitting a Pull Request, you agree that your contributions will be licensed under the project's existing license, granting the author the right to include them in both the open-access and proprietary versions of the software. + +## Development Workflow + +### 1. Environment Setup + +Use the provided setup scripts to ensure a consistent development environment: + +```bash +# Windows +.\setup.ps1 + +# Linux/HPC +bash setup.sh +``` + +### 2. Coding Standards + +We use `black` for formatting and `flake8` for linting. Please ensure your code passes these checks before submitting. + +```bash +black . +flake8 src/ +``` + +### 3. Testing + +All new features must include unit tests in the `tests/` directory. We use `pytest` for our test suite. + +```bash +# Run all tests +.\test.ps1 # Windows +bash test.sh # Linux +``` + +## Pull Request Process + +1. **Open an Issue**: For major changes, please open an issue first to discuss the design. +2. **Branching**: Work on a descriptive feature branch (e.g., `feature/pathway-attention-mask`). +3. **Documentation**: Update relevant files in `docs/` and the `README.md` if your change affects usage. +4. **Verification**: Ensure all CI checks (GitHub Actions) pass. + +## Contact + +For questions regarding commercial licensing or complex architectural changes, please contact the author directly. diff --git a/README.md b/README.md index c2eada9..f7884ff 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,16 @@ # SpatialTranscriptFormer -A transformer-based model for spatial transcriptomics. +> [!WARNING] +> **Work in Progress**: This project is under active development. Core architectures, CLI flags, and data formats are subject to major changes. + +A transformer-based model for spatial transcriptomics that bridges histology and biological pathways. + +## Key Features + +- **Quad-Flow Interaction**: Configurable attention between Pathways and Histology patches (`p2p`, `p2h`, `h2p`, `h2h`). +- **Pathway Bottleneck**: Interpretable gene expression prediction via 50 MSigDB Hallmark tokens. +- **Spatial Pattern Coherence**: Optimized using a composite **MSE + PCC (Pearson Correlation) loss** to prevent spatial collapse and ensure accurate morphology-expression mapping. +- **Biologically Informed Initialization**: Gene reconstruction weights derived from known hallmark memberships. ## License @@ -25,71 +35,66 @@ This project requires [Conda](https://docs.conda.io/en/latest/). ## Usage -After installation, the following command-line tools are available in your `SpatialTranscriptFormer` environment: - ### Download HEST Data Download specific subsets using filters or patterns: ```bash -# List available organs -stf-download --list_organs - # Download only the Bowel Cancer subset (including ST data and WSIs) stf-download --organ Bowel --disease Cancer --local_dir hest_data - -# Download any other organ -stf-download --organ Kidney ``` -### Split Dataset +### Train Models + +We provide presets for baseline models and scaled versions of the SpatialTranscriptFormer. -Perform patient-stratified splitting on the metadata: +```bash +# Recommended: Run the Interaction model with 4 transformer layers +python scripts/run_preset.py --preset stf_interaction_l4 -```powershell -stf-split HEST_v1_3_0.csv --val_ratio 0.2 +# Run the lightweight 2-layer version +python scripts/run_preset.py --preset stf_interaction_l2 + +# Run baselines +python scripts/run_preset.py --preset he2rna_baseline ``` -### Train Models +For a complete list of configurations, see the [Training Guide](docs/TRAINING_GUIDE.md). -Train baseline models (HE2RNA, ViT) or the proposed interaction model. For a complete list of configurations and examples, see the [Training Guide](docs/TRAINING_GUIDE.md). +### Real-Time Monitoring -```bash -# Option 1: Using the standard command -stf-train --data-dir A:\hest_data --model he2rna --epochs 20 +Monitor training progress, loss curves, and **prediction variance (collapse detector)** via the web dashboard: -# Option 2: Using the preset launcher (recommended for complex models) -python scripts/run_preset.py --preset stf_interaction --epochs 30 +```bash +python scripts/monitor.py --run-dir runs/stf_interaction_l4 ``` ### Inference & Visualization -Generate spatial maps comparing Ground Truth vs Predictions for specific samples: +Generate spatial maps comparing Ground Truth vs Predictions: ```bash -stf-predict --data-dir A:\hest_data --sample-id MEND29 --model-path checkpoints/best_model_he2rna.pth --model-type he2rna +stf-predict --data-dir A:\hest_data --sample-id MEND29 --model-path checkpoints/best_model.pth --model-type interaction ``` Visualization plots will be saved to the `./results` directory. ## Documentation -For detailed information on the data and code implementation, see: - +- [Models](docs/MODELS.md): Detailed model architectures and scaling parameters. - [Data Structure](docs/DATA_STRUCTURE.md): Organization of HEST data on disk. -- [Dataloader](docs/DATALOADER.md): Technical implementation of the PyTorch dataset and loaders. -- [Gene Analysis](docs/GENE_ANALYSIS.md): Analysis of available genes and modeling strategies. -- [Pathway Mapping](docs/PATHWAY_MAPPING.md): Strategies for clinical interpretability and pathway integration. -- [Latent Discovery](docs/LATENT_DISCOVERY.md): Unsupervised discovery of biological pathways from data. -- [Models](docs/MODELS.md): Model architectures and literature references. +- [Pathway Mapping](docs/PATHWAY_MAPPING.md): Clinical interpretability and pathway integration. +- [Gene Analysis](docs/GENE_ANALYSIS.md): Modeling strategies for high-dimensional gene space. ## Development ### Running Tests -Use the included test wrapper: - ```bash -# Run all tests +# Run all tests (Pytest wrapper) .\test.ps1 ``` + +## Contributing + +We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details on our coding standards and the process for submitting pull requests. Note that this project is under a proprietary license; contributions involve an assignment of rights for non-academic use. diff --git a/config.yaml b/config.yaml index 19384fd..ae788e3 100644 --- a/config.yaml +++ b/config.yaml @@ -4,15 +4,12 @@ # Data Paths # Candidates for the HEST data directory (checked in order) data_dirs: - - "hest_data" - - "../hest_data" - - "./data" - "A:\\hest_data" # Training Defaults training: num_genes: 1000 - batch_size: 32 + batch_size: 8 learning_rate: 0.0001 output_dir: "./checkpoints" diff --git a/docs/MODELS.md b/docs/MODELS.md index a7e7648..0b9dc54 100644 --- a/docs/MODELS.md +++ b/docs/MODELS.md @@ -1,194 +1,249 @@ -# Model Zoo and Literature References +# SpatialTranscriptFormer: Architecture & Design -This document provides a summary of the models implemented in this project, their origins in literature, and the architectural details of the **SpatialTranscriptFormer**. - -## 1. Regression Models (Patch-Level) - -### HE2RNA (ResNet-50) - -- **Reference**: [Schmauch et al. (2020). "A deep learning model to predict RNA-Seq expression of tumours from whole slide images." Nature Communications.](https://www.nature.com/articles/s41467-020-17679-z) -- **Description**: Uses a ResNet-50 backbone to extract features from histology patches and directly regresses a high-dimensional gene expression vector. -- **Equation**: - $$\hat{y} = \text{FC}(\text{ResNet50}(x))$$ - Where $x$ is the histology patch and $\hat{y} \in \mathbb{R}^G$ is the predicted expression for $G$ genes. - -### ViT-ST - -- **Description**: An adaptation of the Vision Transformer (ViT) to the spatial transcriptomics task by replacing the classification head with a regression head. -- **Equation**: - $$\hat{y} = \text{MLP}(\text{TransformerEncoder}(x_{tokens}))$$ +This document describes the architecture, design philosophy, and training objectives of the **SpatialTranscriptFormer** model, along with the baseline models implemented for comparison. --- -## 2. Multiple Instance Learning (Slide-Level) - -### Attention-MIL - -- **Reference**: [Ilse et al. (2018). "Attention-based Deep Multiple Instance Learning." ICML.](https://arxiv.org/abs/1802.04712) -- **Description**: Learns attention weights for individual patches to aggregate them into a slide-level representation. -- **Aggregation**: - $$z = \sum_{i=1}^N a_i h_i, \quad a_i = \frac{\exp(\mathbf{w}^\top \tanh(\mathbf{V} h_i^\top) \odot \text{sigm}(\mathbf{U} h_i^\top))}{\sum_{j=1}^N \exp(\mathbf{w}^\top \tanh(\mathbf{V} h_j^\top) \odot \text{sigm}(\mathbf{U} h_j^\top))}$$ - -### TransMIL - -- **Reference**: [Shao et al. (2021). "TransMIL: Transformer based Correlated Multiple Instance Learning for Whole Slide Image Classification." NeurIPS.](https://arxiv.org/abs/2106.00908) -- **Description**: Uses a Nyström-based linear transformer to capture correlations between patches across the entire slide. - -### Weak Supervision for MIL - -MIL models are trained with **bag-level supervision**: the model predicts a single slide-level expression vector and is supervised against the mean expression across all valid spots. - -Given a slide with $N$ valid spots and gene expression matrix $\mathbf{Y} \in \mathbb{R}^{N \times G}$, the **bag-level target** is: -$$\bar{y}_g = \frac{1}{N} \sum_{i=1}^N y_{ig}$$ - -For padded batches with mask $\mathbf{m} \in \{0, 1\}^{N}$, this becomes: -$$\bar{y}_g = \frac{\sum_{i=1}^N (1 - m_i) \cdot y_{ig}}{\sum_{i=1}^N (1 - m_i)}$$ - -The training loss is then: -$$\mathcal{L}_{weak} = \frac{1}{G} \sum_{g=1}^G (\hat{y}_g - \bar{y}_g)^2$$ +## 1. Problem Statement -### Spatial Attention Correlation +**Goal**: Predict spatially-resolved gene expression from histology images. -To evaluate whether MIL attention maps recover spatial localisation of gene activity, we compute the **Pearson correlation** between attention weights and total gene expression at each spot: -$$\rho = \text{corr}\left(\mathbf{a},\ \sum_{g=1}^G \mathbf{y}_g\right)$$ +**Data**: Spatial transcriptomics datasets (a subset of HEST-1k, filtered to bowel cancer from human patients) where each tissue section has: -where $\mathbf{a} \in \mathbb{R}^N$ is the vector of attention weights and $\sum_g \mathbf{y}_g \in \mathbb{R}^N$ is the total expression per spot. A high $\rho$ indicates the model has learned to attend to transcriptionally active regions. +- A whole-slide histology image (H&E) +- Per-spot gene expression counts with spatial coordinates -### Dense Supervision (Masked MSE) - -For models that support per-spot dense prediction in whole-slide mode (e.g., SpatialTranscriptFormer via `forward_dense`), we use a **masked MSE** that ignores padded positions: -$$\mathcal{L}_{dense} = \frac{\sum_{i=1}^N \sum_{g=1}^G (1 - m_i)(\hat{y}_{ig} - y_{ig})^2}{\sum_{i=1}^N (1 - m_i) \cdot G}$$ +**Challenge**: Directly predicting ~1000 genes from image patches is high-dimensional, noisy, and biologically uninterpretable. We need a structured bottleneck that compresses the gene space into biologically meaningful abstractions. --- -## 3. SpatialTranscriptFormer (Proposed Model) - -The **SpatialTranscriptFormer** (formerly Pathway Interaction Model) introduces a biologically-informed bottleneck layer using Cross-Attention between Learned Pathways and Image Features. - -### 3.1 Architecture Equations - -#### 1. Image Encoding - -Given an image patch $x$, we extract features $F$ using a backbone (e.g., ResNet or ViT): -$$F = \text{Encoder}(x), \quad F \in \mathbb{R}^{D}$$ +## 2. SpatialTranscriptFormer (Proposed Model) + +### 2.1 Design Philosophy + +The SpatialTranscriptFormer models the **interaction between biological pathways and histology** via four configurable information flows: + +1. **P↔P (Pathway self-interaction)**: Pathways refine each other's representations, capturing biological co-dependencies — e.g., EMT and inflammatory response pathways often co-activate in tumour microenvironments. + +2. **P→H (Pathway queries Histology)**: Pathway tokens query patch features with cross-attention, asking *"does this tissue region show morphological evidence of this biological process?"* — e.g., does a patch look consistent with angiogenesis or epithelial-mesenchymal transition? + +3. **H→P (Histology reads Pathways)**: Patch tokens attend to pathway tokens, receiving biological context — e.g., *"this patch is in a region where the inflammatory response pathway is highly active."* This contextualises the visual features with global biological state. + +4. **H↔H (Patch self-interaction)**: Patches attend to each other, enabling the model to capture spatial relationships between tissue regions directly. + +By default, the model operates in **Full Interaction** mode where all four information flows are active. Users can selectively disable any combination using the `--interactions` flag to explore architectural variants: + +```bash +# Default: Full Interaction (all quadrants enabled) +--interactions p2p p2h h2p h2h + +# Pathway Bottleneck: block H↔H to force all inter-patch +# communication through the pathway bottleneck +--interactions p2p p2h h2p +``` + +> [!TIP] +> The **Pathway Bottleneck** variant (disabling `h2h`) is particularly useful for **interpretability** — all spatial interactions are mediated by named biological pathways — and for **anti-collapse** — preventing patches from averaging into identical representations. + +Three additional design principles support these interactions: + +- **Frozen Foundation Model Backbone** — The visual backbone (CTransPath, Phikon, etc.) is a pre-trained pathology feature extractor. It is never fine-tuned. The model learns only the pathway-histology interactions, keeping training lightweight. + +- **Dense Spatial Supervision** — Unlike weak MIL (which uses slide-level labels), we supervise at the **spot level** using spatial transcriptomics. Every patch receives ground-truth expression, enabling the model to learn spatially-resolved pathway activation patterns. + +- **Biological Initialisation** — The gene reconstruction weights are initialised from MSigDB Hallmark gene sets, providing a biologically-grounded starting point that the model refines during training. + +### 2.2 Spatial Learning + +The spatial relationships of gene expression are central to this model. It is not sufficient to predict correct expression magnitudes at each spot independently — the model must capture **where** on the tissue pathways are active and how that spatial pattern varies across the slide. Two mechanisms enforce this: + +1. **Positional Encoding** — Each patch token receives a 2D sinusoidal encoding of its (x, y) coordinate on the tissue. This means the pathway tokens, when they attend to patches, can distinguish *where* each patch is. A pathway token can learn that EMT is localised at the tumour-stroma boundary, not uniformly across the slide. + +2. **PCC Loss (Spatial Pattern Coherence)** — The Pearson Correlation component in the composite loss measures whether the *spatial pattern* of each gene's predicted expression matches the ground truth pattern, independently of scale. A model that predicts the same value everywhere scores PCC = 0, even if the mean is correct. This directly penalises spatial collapse. + +Together, these ensure the model learns *spatially-varying* pathway activation maps rather than slide-level averages. + +### 2.3 Architecture + +```text +┌──────────────────────────────────────────────────────────────────────────────┐ +│ SpatialTranscriptFormer │ +│ │ +│ ┌─────────────┐ ┌──────────────┐ ┌──────────────────────────┐ │ +│ │ Frozen │ │ Image │ │ + Spatial PE │ │ +│ │ Backbone │──>│ Projection │──>│ (2D Learned) │ │ +│ │ (CTransPath)│ │ (Linear) │ │ │ │ +│ └─────────────┘ └──────────────┘ └────────┬─────────────────┘ │ +│ │ Patch Tokens (S, D) │ +│ ┌──────────────────────────┐ │ │ +│ │ Learnable Pathway │ │ │ +│ │ Tokens (P, D) │────────┐ │ │ +│ │ (MSigDB Hallmarks) │ │ │ │ +│ └──────────────────────────┘ ▼ ▼ │ +│ ┌─────────────────────────────┐ │ +│ │ Transformer Encoder │ │ +│ │ Sequence: [Pathways|Patches]│ │ +│ │ │ │ +│ │ Full Interaction (default): │ │ +│ │ • P↔P ✅ P→H ✅ │ │ +│ │ • H→P ✅ H↔H ✅ │ │ +│ │ │ │ +│ │ Configurable via │ │ +│ │ --interactions flag │ │ +│ └──────────┬──────────────────┘ │ +│ │ │ +│ ┌──────────▼──────────────────┐ │ +│ │ Cosine Similarity Scoring │ │ +│ │ with Learnable Temperature │ │ +│ │ │ │ +│ │ scores = cos(patch, pathway)│ │ +│ │ × τ │ │ +│ └──────────┬──────────────────┘ │ +│ │ Pathway Scores (S, P) │ +│ ┌───────────┴───────────┐ │ +│ │ │ │ +│ ▼ ▼ │ +│ ┌──────────────────────┐ ┌───────────────────────────┐ │ +│ │ Gene Reconstructor │ │ Auxiliary Pathway Loss │ │ +│ │ (Linear: P → G) │ │ PCC(scores, target_pw) │ │ +│ │ Init: MSigDB │ │ weighted by λ_aux │ │ +│ └──────────┬───────────┘ └───────────┬───────────────┘ │ +│ │ │ │ +│ ▼ ▼ │ +│ Gene Expression (S, G) ℒ_aux = λ(1 − PCC) │ +│ │ │ │ +│ └──────────┬───────────────┘ │ +│ ▼ │ +│ ℒ_total = ℒ_gene + ℒ_aux │ +└──────────────────────────────────────────────────────────────────────────────┘ +``` + +### 2.3 Key Components + +#### Frozen Backbone (Feature Extraction) + +Pre-computed features from a pathology foundation model. The backbone is never fine-tuned. + +| Backbone | Feature Dim | Source | +| :--- | :--- | :--- | +| **CTransPath** | 768 | [Wang et al. (2022)](https://arxiv.org/abs/2111.13324) | +| **GigaPath** | 1536 | [Microsoft Prov-GigaPath](https://hf.co/prov-gigapath/prov-gigapath) | +| **Hibou** | 768 / 1024 | [HistAI Hibou](https://hf.co/histai) | +| **Phikon** | 768 | [Owkin Phikon](https://hf.co/owkin/phikon) | +| **ResNet-50** | 2048 | Torchvision (ImageNet) | -#### 2. Pathway Tokenization +#### Pathway Tokenizer -We initialize $K$ learnable pathway tokens $T_{path} \in \mathbb{R}^K \times D_{token}$. These tokens act as "queries" to search for relevant morphological features in the image. +Learnable embeddings $T \in \mathbb{R}^{P \times D}$ representing biological pathways. These act as [CLS]-like bottleneck tokens, analogous to the "perceiver" cross-attention pattern. -#### 3. Interaction (Unified Early Fusion) +#### Sinusoidal Positional Encoding (2D) -The interaction logic is unified via the **EarlyFusionBlock**, which facilitates a generalized attention mechanism between Pathway tokens ($P$) and Histology tokens ($H$). The model operates on a concatenated sequence $X = [P; H]$. +Patch tokens receive spatial location information via 2D sinusoidal embeddings: +$$H_{PE} = H_{proj} + \text{PE}_{2D}(x, y)$$ +This encodes the absolute position of each patch on the tissue slide, enabling the model to learn spatially-varying pathway activation patterns. -##### Quadrant Masking +#### Configurable Interaction Masking -To control the topology of the interaction and ensure scalability, we apply a mask $M$ to the attention matrix: +The transformer uses a custom attention mask that controls information flow between token groups. By default, **all interactions are enabled** (Full Interaction). You can selectively disable any combination using the `--interactions` flag: -- **P2P (Pathway Self-Interaction)**: Pathways refine each other's latent signatures. -- **P2H (Cross-Attention)**: Pathways query the histology patches for morphological evidence. -- **H2P (Biological Feedback)**: Histology features can be influenced by pathway context. -- **H2H (Global Patch-to-Patch)**: Patches attend to other distal patches. +| Quadrant | Token Flow | Description | +| :--- | :--- | :--- | +| **p2p** | Pathway ↔ Pathway | Pathways refine each other | +| **p2h** | Pathway → Histology | Pathways gather spatial info from patches | +| **h2p** | Histology → Pathway | Patches receive contextualised pathway signals | +| **h2h** | Histology ↔ Histology | Patches attend to each other directly | -> [!IMPORTANT] -> **Default Configuration**: In the standard experiment, **H2H is masked**. This removes the $O(N_H^2)$ memory bottleneck, allowing the model to handle thousands of spots simultaneously while still capturing global pathway-level correlations and local morphology. +> [!NOTE] +> Disabling `h2h` creates the **Pathway Bottleneck** variant, where all inter-patch communication must flow through the pathway tokens. This requires **minimum 2 transformer layers**: Layer 1 lets pathways gather information from patches, and Layer 2 lets patches read the contextualised pathway tokens. -### 3.2 Spatial Inductive Biases +#### Cosine Similarity Scoring -#### Sinusoidal Positional Encoding (2D) +Pathway scores are computed via L2-normalized cosine similarity with a learnable temperature parameter $\tau$ (following CLIP): +$$s_{ij} = \cos(\hat{h}_i, \hat{p}_j) \times \tau$$ +where $\hat{h}_i$ and $\hat{p}_j$ are the L2-normalized processed patch and pathway tokens respectively. This produces scores in $[-\tau, +\tau]$ with meaningful relative differences, avoiding the saturation that occurs with raw dot-products. -We inject knowledge of absolute patch locations $(x, y)$ into the transformer latent space via 2D sinusoidal embeddings: -$$H_{PE} = H_{proj} + \text{PE}_{2D}(x, y)$$ -This ensures the attention mechanism is aware of the relative distances and orientations between histology features. - -#### Local Patch Mixing (Conv) +#### Gene Reconstruction (Biologically-Informed) -Because global `H2H` is typically masked for efficiency, we use a **LocalPatchMixer** to regain immediate spatial context. This module applies a depthwise convolution over the spatial grid *before* global interaction: -$$H_{mixed} = \text{GELU}(\text{DepthwiseConv2d}(H_{grid}))$$ -This allows the model to capture high-bandwidth local morphology (e.g., 3x3 window) while the Transformer focuses on global biology-driven interactions. +A linear layer $W_{recon} \in \mathbb{R}^{G \times P}$ maps pathway scores to gene expression: +$$\hat{y}_g = \sum_{k=1}^P s_k \cdot W_{gk} + b_g$$ -### 3.3 Whole-Slide Dense Prediction (`forward_dense`) +When `--pathway-init` is enabled, $W_{recon}$ is initialised from the MSigDB Hallmark gene sets as a binary membership matrix, giving the model a biologically-grounded starting point where each pathway is connected only to its known member genes. -When predicting gene expression for every spot across a slide, the model leverages its global context differently. Instead of taking a single coordinate as the 'target', it updates all histology tokens simultaneously: +### 2.4 Training Modes -1. **Global Context**: Pathways query the entire slide (or a large masked window). -2. **Dense Head**: Updated histology tokens $H'$ are projected back to gene space: - $$\hat{y}_{dense} = H' \cdot T_{path}^\top \cdot W_{recon}$$ - This ensures that the predicted expression logic is consistent with the pathway bottleneck used during patch-level training. +| Mode | Input | Output | Supervision | +| :--- | :--- | :--- | :--- | +| **Dense (whole-slide)** | All patches from a slide | Per-patch gene predictions $(B, S, G)$ | Masked MSE+PCC at each spot | +| **Global** | All patches from a slide | Slide-level prediction $(B, G)$ | Mean-pooled expression | --- -## 4. Scalability: Nyström Approximation - -To scale effectively to whole-slide contexts or extremely large neighborhoods ($N_H > 1000$), we utilize the **Nyström method** for approximating the self-attention matrix. This reduces the complexity from quadratic $O(N^2)$ to linear $O(N \cdot m)$. - -### Theoretical Intuition +## 3. Baseline Models -The core of Transformer attention is the kernel matrix $\mathbf{K} \in \mathbb{R}^{N \times N}$. Nyström approximates $\mathbf{K}$ using a small set of **landmark points** $m \ll N$: -$$\mathbf{K} \approx \mathbf{C} \mathbf{W}^{+} \mathbf{C}^\top$$ -Where: +### HE2RNA (ResNet-50) -- $\mathbf{C} \in \mathbb{R}^{N \times m}$ contains $m$ landmark columns. -- $\mathbf{W} \in \mathbb{R}^{m \times m}$ is the intersection matrix of those columns. -- $\mathbf{W}^{+}$ is the Moore-Penrose pseudo-inverse. +- **Reference**: [Schmauch et al. (2020), Nature Communications](https://www.nature.com/articles/s41467-020-17679-z) +- Direct regression from patch features to gene expression via a single linear layer. -### Decomposition in Attention +### Attention-MIL -In the context of attention $\text{Softmax}(\frac{QK^\top}{\sqrt{d}})V$, the Nyström approximation allows us to compute the interaction without ever forming the $N \times N$ matrix: -$$\tilde{A} = \text{Softmax}\left(\frac{Q \tilde{K}^\top}{\sqrt{d}}\right) \left[ \text{Softmax}\left(\frac{\tilde{K} \tilde{K}^\top}{\sqrt{d}}\right) \right]^{+} \text{Softmax}\left(\frac{\tilde{K} K^\top}{\sqrt{d}}\right)$$ -Where $\tilde{K}$ are the pooled landmark features. +- **Reference**: [Ilse et al. (2018), ICML](https://arxiv.org/abs/1802.04712) +- Learns gated attention weights to aggregate patches into a slide-level representation. -### Why it matters for Spatial Transcriptomics +### TransMIL -1. **Whole Slide Context**: Standard transformers fail on WSIs (e.g., 20,000 patches). Nyström enables slide-level correlation (implemented in the **TransMIL** model). -2. **Dense Neighborhoods**: Allows the **SpatialTranscriptFormer** to consider hundreds of surrounding patches for a single prediction with minimal GPU memory overhead. +- **Reference**: [Shao et al. (2021), NeurIPS](https://arxiv.org/abs/2106.00908) +- Nyström-based transformer for capturing long-range patch correlations. --- -## 5. Backbone Zoo +## 4. Loss Functions -The model supports multiple state-of-the-art pathology backbones. These are selected via the CLI using the `--backbone` flag. +| `mse` | Masked MSE | Magnitude accuracy at each spot | +| `pcc` | 1 − PCC | Spatial pattern coherence per gene (scale-invariant) | +| `mse_pcc` | MSE + α(1 − PCC) | Balances absolute magnitude and spatial shape | +| `zinb` | ZINB NLL | Zero-Inflated Negative Binomial negative log-likelihood | -| Backbone | Variant | Feature Dim | Source / Reference | -| :--- | :--- | :--- | :--- | -| **ResNet** | `resnet50` | 2048 | Torchvision (ImageNet) | -| **CTransPath** | `ctranspath` | 768 | [Wang et al. (2022)](https://arxiv.org/abs/2111.13324) | -| **GigaPath** | `gigapath` | 1536 | [Microsoft Prov-GigaPath](https://hf.co/prov-gigapath/prov-gigapath) | -| **Hibou** | `hibou-b/l` | 768 / 1024 | [HistAI Hibou](https://hf.co/histai) | -| **Phikon** | `phikon` | 768 | [Owkin Phikon](https://hf.co/owkin/phikon) | -| **PLIP** | `plip` | 512 | [Huang et al. (2023)](https://hf.co/vinid/plip) | +### ZINB Loss -> [!NOTE] -> GigaPath and Hibou are **gated models**. You must accept the terms of use on their respective HuggingFace model pages before the code can download the weights. +The Zero-Inflated Negative Binomial (ZINB) loss is designed for raw, highly dispersed count data. It models the data using three parameters: ---- +- **$\pi$ (pi)**: Probability of zero-inflation (technical dropout). +- **$\mu$ (mu)**: Mean of the negative binomial distribution. +- **$\theta$ (theta)**: Inverse dispersion (clumping) parameter. -## 6. Biologically-Informed Bottleneck +The model outputs these parameters, and the loss computes the negative log-likelihood of the ground truth counts given this distribution. -The final gene expression $\hat{y}$ is a linear combination of pathway activations $s_k$: -$$\hat{y}_g = \sum_{k=1}^K s_k \cdot W_{gk} + b_g$$ +### Auxiliary Pathway Loss -### MSigDB Hallmarks Initialization +To prevent bottleneck collapse and provide a direct gradient signal to the pathway tokens, we use the `AuxiliaryPathwayLoss`. This loss compares the model's internal pathway scores against "ground truth" pathway activations computed from the gene expression targets via MSigDB membership. -When `--pathway-init` is enabled, $\mathbf{W}_{recon}$ is initialized from the **MSigDB Hallmark gene sets** (50 curated biological pathways). A binary membership matrix $\mathbf{M} \in \{0, 1\}^{P \times G}$ is constructed: -$$M_{kg} = \begin{cases} 1 & \text{if gene } g \in \text{pathway } k \\ 0 & \text{otherwise} \end{cases}$$ +The total objective becomes: +$$\mathcal{L} = \mathcal{L}_{gene} + \lambda_{aux} (1 - \text{PCC}(\text{pathway\_scores}, \text{target\_pathways}))$$ -The gene reconstructor weight is then initialized as $\mathbf{W}_{recon} \leftarrow \mathbf{M}^\top \in \mathbb{R}^{G \times P}$. This gives the model a biologically-grounded starting point where each pathway token is connected only to its known member genes. +The `--log-transform` flag applies `log1p` to targets, mitigating the heavy-tailed gene expression distribution where housekeeping genes dominate MSE. ---- - -## 7. Loss Functions and Gene Imbalance - -### The Gene Imbalance Problem +The full training objective with pathway sparsity regularisation: +$$\mathcal{L} = \mathcal{L}_{task} + \lambda \|W_{recon}\|_1$$ -Gene expression follows a **heavy-tailed distribution**: HOUSEKEEPING genes dominate, while signalling genes are rare. With standard MSE, high-expression genes dominate the loss. +--- -The `--log-transform` flag (`log1p`) is the primary mitigation. Additionally, the following loss objectives are supported: +## 5. CLI Flags (Model Configuration) -| CLI Flag | Objective | Focus | +| Flag | Default | Description | | :--- | :--- | :--- | -| `mse` | Mean Squared Error | Magnitude accuracy at each spot. | -| `pcc` | Pearson Correlation | Spatial pattern coherence per gene. | -| `mse_pcc` | Combined Loss | Balances absolute magnitude and spatial shape. | - -The complete training objective with pathway sparsity is: -$$\mathcal{L} = \mathcal{L}_{task} + \lambda \|\mathbf{W}_{recon}\|_1$$ +| `--model interaction` | — | Select SpatialTranscriptFormer | +| `--backbone` | `resnet50` | Foundation model backbone | +| `--token-dim` | 256 | Transformer embedding dimension | +| `--n-heads` | 4 | Number of attention heads | +| `--n-layers` | 2 | Transformer layers (minimum 2) | +| `--num-pathways` | 50 | Number of pathway bottleneck tokens | +| `--pathway-init` | off | Initialize gene_reconstructor from MSigDB | +| `--sparsity-lambda` | 0.0 | L1 regularisation on reconstruction weights | +| `--loss mse_pcc` | `mse` | Loss function (`mse`, `pcc`, `mse_pcc`, `zinb`) | +| `--pcc-weight` | 1.0 | Weight for PCC term in composite loss | +| `--pathway-loss-weight` | 0.0 | Weight for auxiliary pathway loss ($\lambda_{aux}$) | +| `--interactions` | `all` | Enabled interaction quadrants (`p2p p2h h2p h2h`) | +| `--log-transform` | off | Apply log1p to targets | +| `--return-attention` | off | Return attention maps from forward pass (for diagnostics) | +| `--n-neighbors` | 0 | Number of context neighbors (for hybrid/GNN models) | diff --git a/docs/TRAINING_GUIDE.md b/docs/TRAINING_GUIDE.md index 62aa736..3b0d9c0 100644 --- a/docs/TRAINING_GUIDE.md +++ b/docs/TRAINING_GUIDE.md @@ -126,18 +126,42 @@ python -m spatial_transcript_former.train \ > **Note**: `--pathway-init` overrides `--num-pathways` to 50 (the number of Hallmark gene sets). The GMT file is cached in `.cache/` after first download. -### Advanced: Multimodal Masking (Ablation) +### Robust Counting: ZINB + Auxiliary Loss -Test the model's robustness by masking specific quadrants (e.g., top-left and bottom-right). +For raw count data with high sparsity, using the ZINB distribution and auxiliary pathway supervision is recommended. ```bash python -m spatial_transcript_former.train \ --data-dir A:\hest_data \ --model interaction \ - --masked-quadrants top_left bottom_right \ - --precomputed + --backbone ctranspath \ + --pathway-init \ + --loss zinb \ + --pathway-loss-weight 0.5 \ + --lr 5e-5 \ + --batch-size 4 \ + --whole-slide \ + --precomputed \ + --epochs 200 ``` +### Choosing Interaction Modes + +By default, the model runs in **Full Interaction** mode (`p2p p2h h2p h2h`) where all token types attend to each other. You can selectively disable interactions using the `--interactions` flag for ablation or to enforce specific architectural constraints. + +For example, to use the **Pathway Bottleneck** (blocking patch-to-patch attention for interpretability): + +```bash +python -m spatial_transcript_former.train \ + --data-dir A:\hest_data \ + --model interaction \ + --interactions p2p p2h h2p \ + --precomputed \ + --whole-slide +``` + +Available interaction tokens: `p2p`, `p2h`, `h2p`, `h2h`. Default is all four (Full Interaction). + --- ## 4. HPC Batch Experiments @@ -145,7 +169,7 @@ python -m spatial_transcript_former.train \ The `hpc/array_train.slurm` script runs all three whole-slide experiments as a SLURM array job: | Index | Model | Supervision | Key Flags | -|:------|:------|:------------|:----------| +| :--- | :--- | :--- | :--- | | 0 | SpatialTranscriptFormer | Dense | `--whole-slide` | | 1 | AttentionMIL | Weak | `--whole-slide --weak-supervision` | | 2 | TransMIL | Weak | `--whole-slide --weak-supervision` | @@ -173,7 +197,7 @@ This produces a sorted comparison table and `comparison.csv`. Each training run automatically produces: | File | Description | -|:-----|:------------| +| :--- | :--- | | `training_log.csv` | Per-epoch metrics (train_loss, val_loss, attn_correlation) | | `results_summary.json` | Full config + final metrics + runtime | | `best_model_.pth` | Best checkpoint (by val loss) | @@ -196,10 +220,13 @@ python -m spatial_transcript_former.train --resume --output-dir runs/my_experime | `--weak-supervision` | Bag-level training for MIL models. | Use with `attention_mil` or `transmil`. | | `--pathway-init` | Initialize gene_reconstructor from MSigDB Hallmarks. | Use with `interaction` model. | | `--feature-dir` | Explicit path to precomputed features directory. | Overrides auto-detection. | +| `--loss` | Loss function: `mse`, `pcc`, `mse_pcc`, `zinb`. | `mse_pcc` or `zinb` recommended. | +| `--pathway-loss-weight` | Weight ($\lambda$) for auxiliary pathway supervision. | Set `0.5` or `1.0` with `interaction` model. | +| `--interactions` | Enabled attention quadrants: `p2p`, `p2h`, `h2p`, `h2h`. | Default: `all` (Full Interaction). | | `--log-transform` | Apply log1p to gene expression targets. | Recommended for raw count data. | | `--num-genes` | Number of HVGs to predict (default: 1000). | Match your `global_genes.json`. | | `--mask-radius` | Euclidean distance for spatial attention gating. | Usually between 200 and 800. | -| `--n-neighbors` | Number of context neighbors to load. | Set `> 0` for models using spatial context. | +| `--n-neighbors` | Number of context neighbors to load. | Set `> 0` for hybrid/GNN models. | | `--use-amp` | Mixed precision training. | Recommended on modern GPUs. | | `--grad-accum-steps` | Gradient accumulation steps. | Use when memory is limited. | | `--compile` | Use `torch.compile` for speed. | Recommended on Linux/A100. | diff --git a/scripts/diagnose_collapse.py b/scripts/diagnose_collapse.py new file mode 100644 index 0000000..d764933 --- /dev/null +++ b/scripts/diagnose_collapse.py @@ -0,0 +1,414 @@ +""" +Diagnostic script for detecting model collapse in SpatialTranscriptFormer. + +Loads a checkpoint and a sample, then checks: +1. Pathway token diversity (are all pathway embeddings collapsing to the same vector?) +2. Prediction variance (are all output genes/patches getting the same value?) +3. ZINB parameter health (are pi/mu/theta in sensible ranges?) +4. Gradient flow (are gradients reaching all layers?) +5. Attention entropy (are attention weights uniform or peaked?) + +Usage: + python scripts/diagnose_collapse.py --checkpoint runs/stf_interaction_zinb/best_model_interaction.pth --data-dir A:/hest_data +""" + +import argparse +import os +import sys +import json +import torch +import torch.nn.functional as F +import numpy as np + +from spatial_transcript_former.models.interaction import SpatialTranscriptFormer +from spatial_transcript_former.data.dataset import ( + HEST_FeatureDataset, + load_global_genes, +) +from spatial_transcript_former.data.pathways import get_pathway_init, MSIGDB_URLS + + +def load_model( + checkpoint_path, device, num_genes=1000, backbone="ctranspath", loss="zinb" +): + """Load model from checkpoint.""" + gene_list_path = "global_genes.json" + if not os.path.exists(gene_list_path): + gene_list_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "global_genes.json", + ) + with open(gene_list_path) as f: + gene_list = json.load(f) + + urls = [MSIGDB_URLS["hallmarks"]] + pathway_init, pathway_names = get_pathway_init(gene_list[:num_genes], gmt_urls=urls) + num_pathways = len(pathway_names) + + model = SpatialTranscriptFormer( + num_genes=num_genes, + backbone_name=backbone, + pretrained=False, + num_pathways=num_pathways, + pathway_init=pathway_init, + output_mode="zinb" if loss == "zinb" else "counts", + ) + + state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True) + # Handle torch.compile prefix + new_state_dict = {} + for k, v in state_dict.items(): + key = k.replace("_orig_mod.", "") + new_state_dict[key] = v + + model.load_state_dict(new_state_dict, strict=False) + model.to(device) + return model, pathway_names + + +def check_pathway_diversity(model): + """Check if pathway tokens have collapsed to the same vector.""" + pw = model.pathway_tokens.data.squeeze(0) # (P, D) + + # Pairwise cosine similarity + pw_norm = F.normalize(pw, dim=-1) + sim_matrix = pw_norm @ pw_norm.T # (P, P) + + # Off-diagonal similarities + mask = ~torch.eye(sim_matrix.shape[0], dtype=torch.bool, device=sim_matrix.device) + off_diag = sim_matrix[mask] + + mean_sim = off_diag.mean().item() + max_sim = off_diag.max().item() + min_sim = off_diag.min().item() + + # Check variance of each pathway token + pw_std = pw.std(dim=-1) # (P,) + dead_pathways = (pw_std < 1e-6).sum().item() + + print("\n" + "=" * 60) + print("1. PATHWAY TOKEN DIVERSITY") + print("=" * 60) + print(f" Num pathways: {pw.shape[0]}, Token dim: {pw.shape[1]}") + print(f" Mean pairwise cosine sim: {mean_sim:.4f}") + print(f" Max pairwise cosine sim: {max_sim:.4f}") + print(f" Min pairwise cosine sim: {min_sim:.4f}") + print(f" Dead pathways (zero std): {dead_pathways}") + + if mean_sim > 0.9: + print(" ⚠️ WARNING: Pathway tokens are highly similar — possible collapse!") + elif mean_sim > 0.7: + print(" ⚠️ CAUTION: Pathway tokens are somewhat similar") + else: + print(" ✅ Pathway tokens appear diverse") + + return mean_sim + + +def check_gene_reconstructor(model): + """Check if gene_reconstructor weights are degenerate.""" + W = model.gene_reconstructor.weight.data # (G, P) + + col_std = W.std(dim=0) # Per-pathway variance across genes + row_std = W.std(dim=1) # Per-gene variance across pathways + + dead_pathways = (col_std < 1e-6).sum().item() + dead_genes = (row_std < 1e-6).sum().item() + + # Check sparsity (fraction of near-zero weights) + sparsity = (W.abs() < 1e-4).float().mean().item() + + print("\n" + "=" * 60) + print("2. GENE RECONSTRUCTOR WEIGHTS") + print("=" * 60) + print(f" Shape: {W.shape}") + print(f" Weight range: [{W.min().item():.4f}, {W.max().item():.4f}]") + print(f" Weight std: {W.std().item():.4f}") + print(f" Dead pathways (zero col std): {dead_pathways}/{W.shape[1]}") + print(f" Dead genes (zero row std): {dead_genes}/{W.shape[0]}") + print(f" Sparsity (|w| < 1e-4): {sparsity:.2%}") + + if dead_pathways > W.shape[1] * 0.5: + print(" ⚠️ WARNING: >50% of pathways are dead in gene_reconstructor!") + else: + print(" ✅ Gene reconstructor weights appear healthy") + + +def check_predictions(model, feats, coords, device): + """Run a forward pass and check prediction diversity.""" + model.eval() + with torch.no_grad(): + feats_t = feats.unsqueeze(0).to(device) + coords_t = coords.unsqueeze(0).to(device) + + output = model( + feats_t, return_dense=True, rel_coords=coords_t, return_pathways=True + ) + + if isinstance(output, tuple): + gene_output = output[0] + pathway_scores = output[1] + else: + gene_output = output + pathway_scores = None + + print("\n" + "=" * 60) + print("3. PREDICTION DIVERSITY") + print("=" * 60) + + if isinstance(gene_output, tuple): + # ZINB output: (pi, mu, theta) + pi, mu, theta = gene_output + pi, mu, theta = pi.squeeze(0), mu.squeeze(0), theta.squeeze(0) + + print(f" ZINB mode detected") + print(f" mu range: [{mu.min().item():.4f}, {mu.max().item():.4f}]") + print(f" mu mean: {mu.mean().item():.4f}") + print(f" mu std: {mu.std().item():.4f}") + print( + f" pi range: [{pi.min().item():.4f}, {pi.max().item():.4f}] (dropout probability)" + ) + print(f" pi mean: {pi.mean().item():.4f}") + print( + f" theta range: [{theta.min().item():.4f}, {theta.max().item():.4f}] (dispersion)" + ) + print(f" theta mean: {theta.mean().item():.4f}") + + # Spatial variance of mu (do predictions vary across patches?) + mu_spatial_std = mu.std(dim=0).mean().item() # avg per-gene spatial std + print(f"\n Spatial std of mu (per-gene avg): {mu_spatial_std:.6f}") + + # Per-patch variance (do predictions vary across genes?) + mu_gene_std = mu.std(dim=1).mean().item() + print(f" Gene std of mu (per-patch avg): {mu_gene_std:.6f}") + + if mu_spatial_std < 1e-4: + print( + " ⚠️ WARNING: Almost zero spatial variation — model is predicting the same thing everywhere!" + ) + elif mu_spatial_std < 1e-2: + print(" ⚠️ CAUTION: Very low spatial variation") + else: + print(" ✅ Spatial variation appears present") + + if pi.mean().item() > 0.9: + print( + " ⚠️ WARNING: pi is very high — model thinks almost everything is a dropout!" + ) + elif pi.mean().item() < 0.01: + print(" ✅ pi is low — model is not over-relying on zero inflation") + else: + preds = gene_output.squeeze(0) # (N, G) + print( + f" Prediction range: [{preds.min().item():.4f}, {preds.max().item():.4f}]" + ) + spatial_std = preds.std(dim=0).mean().item() + print(f" Spatial std (per-gene avg): {spatial_std:.6f}") + + if spatial_std < 1e-4: + print(" ⚠️ WARNING: Model is predicting the same thing everywhere!") + + if pathway_scores is not None: + pw = pathway_scores.squeeze(0) # (N, P) + print(f"\n Pathway scores shape: {pw.shape}") + print( + f" Pathway scores range: [{pw.min().item():.4f}, {pw.max().item():.4f}]" + ) + print(f" Pathway scores std: {pw.std().item():.4f}") + + # Per-pathway spatial variance + pw_spatial_std = pw.std(dim=0) # (P,) + print( + f" Per-pathway spatial std: min={pw_spatial_std.min().item():.4f}, max={pw_spatial_std.max().item():.4f}, mean={pw_spatial_std.mean().item():.4f}" + ) + + dead_pw = (pw_spatial_std < 1e-6).sum().item() + print(f" Dead pathways (zero spatial var): {dead_pw}/{pw.shape[1]}") + + +def check_attention_interactions(model, feats, coords, device): + """Analyze attention maps to see how different token groups are interacting.""" + model.eval() + with torch.no_grad(): + feats_t = feats.unsqueeze(0).to(device) + coords_t = coords.unsqueeze(0).to(device) + + # SpatialTranscriptFormer.forward returns (gene_expr, pathway_scores, attentions) + # when return_pathways=True and return_attention=True + results = model( + feats_t, + rel_coords=coords_t, + return_pathways=True, + return_attention=True, + ) + + # Unpack results: (expr, p_scores, attentions) + if len(results) == 3: + attentions = results[2] + else: + # Fallback for unexpected return signature (shouldn't happen with my changes) + print( + " ⚠️ Could not extract attention maps (unexpected return signature)" + ) + return + + p = model.num_pathways + s = feats.shape[0] # number of spots + + print("\n" + "=" * 60) + print("5. ATTENTION INTERACTIONS") + print("=" * 60) + print(f" Interactions enabled: {model.interactions}") + + for i, attn in enumerate(attentions): + # attn is (B, T, T) where T = P + S + attn = attn.squeeze(0) # (T, T) + + # Interaction quadrants + p2p_m = attn[:p, :p].mean().item() + p2h_m = attn[:p, p:].mean().item() + h2p_m = attn[p:, :p].mean().item() + h2h_m = attn[p:, p:].mean().item() + + # Diagnostic for sparse attention: off-diagonal h2h + h2h_off_diag = attn[p:, p:].clone() + h2h_off_diag.fill_diagonal_(0) + h2h_off_diag_mean = h2h_off_diag.mean().item() + + print(f" LAYER {i}:") + print(f" p2p (Path -> Path): {p2p_m:.6f}") + print(f" p2h (Path -> Spot): {p2h_m:.6f}") + print(f" h2p (Spot -> Path): {h2p_m:.6f}") + print( + f" h2h (Spot -> Spot): {h2h_m:.6f} (off-diag: {h2h_off_diag_mean:.6f})" + ) + + # Warnings for "dead" interaction types + if h2p_m < 1e-4 and "h2p" in model.interactions: + print( + f" ⚠️ CAUTION: Spots are barely attending to pathways in Layer {i}" + ) + if p2h_m < 1e-4 and "p2h" in model.interactions: + print( + f" ⚠️ CAUTION: Pathways are barely attending to spots in Layer {i}" + ) + if h2h_off_diag_mean < 1e-7 and "h2h" in model.interactions: + print( + f" ⚠️ WARNING: Spot-to-spot attention is nearly zero despite being enabled!" + ) + + +def check_gradient_flow(model, feats, coords, device): + """Check if gradients reach all parts of the model.""" + model.train() + model.zero_grad() + + feats_t = feats.unsqueeze(0).to(device) + coords_t = coords.unsqueeze(0).to(device) + + output = model(feats_t, return_dense=True, rel_coords=coords_t) + + if isinstance(output, tuple): + loss = output[1].sum() # mu + else: + loss = output.sum() + + loss.backward() + + print("\n" + "=" * 60) + print("4. GRADIENT FLOW") + print("=" * 60) + + layers = { + "pathway_tokens": model.pathway_tokens, + "image_proj.weight": model.image_proj.weight, + "gene_reconstructor.weight": model.gene_reconstructor.weight, + "fusion_engine.layers[0].self_attn.in_proj_weight": model.fusion_engine.layers[ + 0 + ].self_attn.in_proj_weight, + } + + if hasattr(model, "pi_reconstructor"): + layers["pi_reconstructor.weight"] = model.pi_reconstructor.weight + layers["theta_reconstructor.weight"] = model.theta_reconstructor.weight + + for name, param in layers.items(): + if param.grad is not None: + grad_norm = param.grad.norm().item() + grad_max = param.grad.abs().max().item() + print(f" {name:45s} grad_norm={grad_norm:.6f} grad_max={grad_max:.6f}") + if grad_norm < 1e-10: + print(f" ⚠️ DEAD gradient in {name}!") + else: + print(f" {name:45s} ⚠️ NO GRADIENT!") + + +def main(): + parser = argparse.ArgumentParser(description="Diagnose model collapse") + parser.add_argument("--checkpoint", type=str, required=True) + parser.add_argument("--data-dir", type=str, required=True) + parser.add_argument("--num-genes", type=int, default=1000) + parser.add_argument("--backbone", type=str, default="ctranspath") + parser.add_argument("--loss", type=str, default="zinb") + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + + # Load model + print("Loading model...") + model, pathway_names = load_model( + args.checkpoint, device, args.num_genes, args.backbone, args.loss + ) + print(f"Loaded {len(pathway_names)} pathways") + + # Load a sample + print("Loading sample data...") + common_gene_names = load_global_genes(args.data_dir, args.num_genes) + + patches_dir = os.path.join(args.data_dir, "patches") + if not os.path.isdir(patches_dir): + patches_dir = args.data_dir + st_dir = os.path.join(args.data_dir, "st") + + feat_dir = os.path.join(args.data_dir, f"patches/he_features_{args.backbone}") + if not os.path.isdir(feat_dir): + feat_dir = os.path.join(args.data_dir, f"he_features_{args.backbone}") + + # Find first available sample + sample_files = [f for f in os.listdir(feat_dir) if f.endswith(".pt")] + if not sample_files: + print("ERROR: No feature files found!") + sys.exit(1) + + sample_id = sample_files[0].replace(".pt", "") + feature_path = os.path.join(feat_dir, sample_files[0]) + h5ad_path = os.path.join(st_dir, f"{sample_id}.h5ad") + + print(f"Using sample: {sample_id}") + + ds = HEST_FeatureDataset( + feature_path, + h5ad_path, + num_genes=args.num_genes, + selected_gene_names=common_gene_names, + whole_slide_mode=True, + ) + + feats, targets, coords = ds[0] + print(f"Features shape: {feats.shape}, Coords shape: {coords.shape}") + + # Run diagnostics + check_pathway_diversity(model) + check_gene_reconstructor(model) + check_predictions(model, feats, coords, device) + check_attention_interactions(model, feats, coords, device) + check_gradient_flow(model, feats, coords, device) + + print("\n" + "=" * 60) + print("DIAGNOSIS COMPLETE") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/scripts/monitor.py b/scripts/monitor.py index e9c959f..a3270ce 100644 --- a/scripts/monitor.py +++ b/scripts/monitor.py @@ -90,7 +90,7 @@ def serve_image(filename): ), html.Div( [ - # Graph for Losses + # Row 1: Losses + Correlation html.Div( [dcc.Graph(id="live-loss-graph", animate=False)], style={ @@ -99,9 +99,25 @@ def serve_image(filename): "verticalAlign": "top", }, ), - # Graph for Metrics html.Div( - [dcc.Graph(id="live-metric-graph", animate=False)], + [dcc.Graph(id="live-pcc-graph", animate=False)], + style={ + "width": "48%", + "display": "inline-block", + "verticalAlign": "top", + }, + ), + # Row 2: Variance + Learning Rate + html.Div( + [dcc.Graph(id="live-variance-graph", animate=False)], + style={ + "width": "48%", + "display": "inline-block", + "verticalAlign": "top", + }, + ), + html.Div( + [dcc.Graph(id="live-lr-graph", animate=False)], style={ "width": "48%", "display": "inline-block", @@ -157,80 +173,97 @@ def update_image(n): return html.Img(src=url, style={"maxWidth": "100%", "height": "auto"}) +def _make_traces(df, cols, smoothing_window): + """Create Plotly traces for the given columns with optional smoothing.""" + traces = [] + for col in cols: + if col not in df.columns: + continue + y_data = df[col].dropna() + epochs = df.loc[y_data.index, "epoch"] + if smoothing_window and smoothing_window > 1: + y_data = y_data.rolling(window=smoothing_window, min_periods=1).mean() + traces.append(go.Scatter(x=epochs, y=y_data, mode="lines", name=col)) + return traces + + @app.callback( [ Output("live-loss-graph", "figure"), - Output("live-metric-graph", "figure"), + Output("live-pcc-graph", "figure"), + Output("live-variance-graph", "figure"), + Output("live-lr-graph", "figure"), Output("last-updated", "children"), ], [Input("interval-component", "n_intervals"), Input("smoothing-slider", "value")], ) def update_graphs(n, smoothing_window): + empty = dash.no_update if not os.path.exists(log_path): - return ( - dash.no_update, - dash.no_update, - "Waiting for training_log.csv to be created...", - ) + return empty, empty, empty, empty, "Waiting for training_log.csv..." try: df = pd.read_csv(log_path) except Exception as e: - return dash.no_update, dash.no_update, f"Error reading log file: {str(e)}" + return empty, empty, empty, empty, f"Error reading log: {e}" if df.empty or "epoch" not in df.columns: - return ( - dash.no_update, - dash.no_update, - "Log file is empty or missing 'epoch' column.", - ) + return empty, empty, empty, empty, "Log empty or missing 'epoch'." - # Subplot 1: Losses - loss_cols = [c for c in df.columns if "loss" in c.lower()] - loss_traces = [] - for col in loss_cols: - y_data = df[col] - if smoothing_window and smoothing_window > 1: - y_data = y_data.rolling(window=smoothing_window, min_periods=1).mean() + margin = dict(l=40, r=40, t=40, b=40) - loss_traces.append(go.Scatter(x=df["epoch"], y=y_data, mode="lines", name=col)) - - loss_layout = go.Layout( - title="Training vs Validation Loss", - xaxis=dict(title="Epoch"), - yaxis=dict( - title="Loss", type="log" - ), # Log scale often helps with early MSE spikes - margin=dict(l=40, r=40, t=40, b=40), - ) + # Chart 1: Losses (log scale) + loss_cols = [c for c in df.columns if "loss" in c.lower()] + loss_fig = { + "data": _make_traces(df, loss_cols, smoothing_window), + "layout": go.Layout( + title="Loss", + xaxis=dict(title="Epoch"), + yaxis=dict(title="Loss", type="log"), + margin=margin, + ), + } - # Subplot 2: Interpretability Metrics - metric_cols = [c for c in df.columns if c not in loss_cols and c != "epoch"] - metric_traces = [] - for col in metric_cols: - y_data = df[col] - if smoothing_window and smoothing_window > 1: - y_data = y_data.rolling(window=smoothing_window, min_periods=1).mean() + # Chart 2: Correlation (PCC, MAE) + corr_cols = [c for c in ["val_pcc", "val_mae"] if c in df.columns] + pcc_fig = { + "data": _make_traces(df, corr_cols, smoothing_window), + "layout": go.Layout( + title="Correlation & Error", + xaxis=dict(title="Epoch"), + yaxis=dict(title="Score"), + margin=margin, + ), + } - metric_traces.append( - go.Scatter(x=df["epoch"], y=y_data, mode="lines", name=col) - ) + # Chart 3: Prediction Variance + var_cols = [c for c in ["pred_variance"] if c in df.columns] + var_fig = { + "data": _make_traces(df, var_cols, smoothing_window), + "layout": go.Layout( + title="Prediction Variance (collapse detector)", + xaxis=dict(title="Epoch"), + yaxis=dict(title="Variance", type="log"), + margin=margin, + ), + } - metric_layout = go.Layout( - title="Interpretability Metrics (MAE, PCC, Correlation)", - xaxis=dict(title="Epoch"), - yaxis=dict(title="Score / Error"), - margin=dict(l=40, r=40, t=40, b=40), - ) + # Chart 4: Learning Rate + lr_cols = [c for c in ["lr"] if c in df.columns] + lr_fig = { + "data": _make_traces(df, lr_cols, smoothing_window), + "layout": go.Layout( + title="Learning Rate Schedule", + xaxis=dict(title="Epoch"), + yaxis=dict(title="LR", type="log"), + margin=margin, + ), + } last_epoch = df["epoch"].iloc[-1] update_text = f"Last updated: Epoch {last_epoch} (Polled automatically)" - return ( - {"data": loss_traces, "layout": loss_layout}, - {"data": metric_traces, "layout": metric_layout}, - update_text, - ) + return loss_fig, pcc_fig, var_fig, lr_fig, update_text @app.callback( diff --git a/scripts/run_preset.py b/scripts/run_preset.py index 3e60dfa..8e51d2b 100644 --- a/scripts/run_preset.py +++ b/scripts/run_preset.py @@ -3,7 +3,26 @@ import sys import os +from spatial_transcript_former.config import get_config + +# Common flags for all STF interaction models +STF_COMMON = [ + "--model", + "interaction", + "--backbone", + "ctranspath", + "--precomputed", + "--whole-slide", + "--pathway-init", + "--use-amp", + "--log-transform", + "--loss", + "mse_pcc", + "--resume", +] + PRESETS = { + # --- Baselines --- "he2rna_baseline": [ "--model", "he2rna", @@ -36,58 +55,42 @@ "--batch-size", "1", ], - "stf_pathway_nystrom": [ - "--model", - "interaction", - "--backbone", - "ctranspath", - "--precomputed", - "--whole-slide", - "--use-nystrom", - "--pathway-init", - "--sparsity-lambda", - "0.05", - "--lr", - "1e-4", + # --- Interaction Models (Layer Scaling) --- + "stf_interaction_l2": STF_COMMON + + [ + "--n-layers", + "2", + "--token-dim", + "256", + "--n-heads", + "4", "--batch-size", - "8", - "--epochs", - "2000", - "--log-transform", - "--use-amp", - "--loss", - "mse_pcc", - "--resume", + "4", ], - "stf_pathway": [ - "--model", - "interaction", - "--backbone", - "ctranspath", - "--precomputed", - "--whole-slide", - "--pathway-init", - "--pathways", - "KEGG_COLORECTAL_CANCER", - "GRADE_COLON_CANCER_UP", - "GRADE_COLON_CANCER_DN", - "GRADE_COLON_AND_RECTAL_CANCER_UP", - "GRADE_COLON_AND_RECTAL_CANCER_DN", - "--sparsity-lambda", - "0.01", - "--lr", - "1e-5", + "stf_interaction_l4": STF_COMMON + + [ + "--n-layers", + "4", + "--token-dim", + "384", + "--n-heads", + "8", "--batch-size", + "4", + ], + "stf_interaction_l6": STF_COMMON + + [ + "--n-layers", + "6", + "--token-dim", + "512", + "--n-heads", "8", - "--epochs", - "2000", - "--log-transform", - "--use-amp", - "--loss", - "mse_pcc", - "--resume", + "--batch-size", + "2", # Reduced batch size for large model memory ], - "stf_pathway_hybrid": [ + # --- Specific Configurations --- + "stf_interaction_zinb": [ "--model", "interaction", "--backbone", @@ -95,54 +98,49 @@ "--precomputed", "--whole-slide", "--pathway-init", - "--pathways", - "KEGG_COLORECTAL_CANCER", - "GRADE_COLON_CANCER_UP", - "GRADE_COLON_CANCER_DN", - "GRADE_COLON_AND_RECTAL_CANCER_UP", - "GRADE_COLON_AND_RECTAL_CANCER_DN", "--sparsity-lambda", - "0.01", + "0", "--lr", "1e-4", "--batch-size", - "8", + "4", "--epochs", "2000", - "--log-transform", "--use-amp", "--loss", - "mse_pcc", + "zinb", + "--log-transform", + "--pathway-loss-weight", + "0.5", + "--interactions", + "p2p", + "p2h", + "h2p", + "h2h", + "--plot-pathways", "--resume", ], - "stf_pathway_gnn": [ - "--model", - "interaction", - "--backbone", - "ctranspath", - "--precomputed", - "--whole-slide", - "--pathway-init", + # Legacy / Specialized + "stf_pathway_nystrom": STF_COMMON + + [ "--sparsity-lambda", - "0.01", + "0.05", "--lr", "1e-4", "--batch-size", "8", "--epochs", "2000", - "--log-transform", - "--use-amp", - "--loss", - "mse", - "--resume", - "--early-mixer", - "none", - "--late-refiner", - "gnn", ], } +# Add alias for the one currently running to ensure backward compatibility for монитор +PRESETS["stf_interaction_mse_pcc"] = PRESETS["stf_interaction_l2"] + [ + "--pathway-loss-weight", + "0.5", + "--plot-pathways", +] + def main(): parser = argparse.ArgumentParser( @@ -161,7 +159,7 @@ def main(): default=get_config("data_dirs", ["A:\\hest_data"])[0], help="Data directory", ) - parser.add_argument("--epochs", type=int, default=10, help="Number of epochs") + parser.add_argument("--epochs", type=int, default=2000, help="Number of epochs") parser.add_argument("--max-samples", type=int, default=None, help="Limit samples") parser.add_argument( "--output-dir", type=str, default=None, help="Override output dir" diff --git a/setup.ps1 b/setup.ps1 index 896303f..9c1bd04 100644 --- a/setup.ps1 +++ b/setup.ps1 @@ -7,8 +7,8 @@ $EnvName = "SpatialTranscriptFormer" # Check if conda environment exists $CondaEnv = conda env list | Select-String $EnvName if ($null -eq $CondaEnv) { - Write-Host "Creating conda environment '$EnvName' with Python 3.10..." -ForegroundColor Yellow - conda create -n $EnvName python=3.10 -y + Write-Host "Creating conda environment '$EnvName' with Python 3.9..." -ForegroundColor Yellow + conda create -n $EnvName python=3.9 -y } else { Write-Host "Conda environment '$EnvName' already exists." -ForegroundColor Green @@ -22,5 +22,6 @@ Write-Host "Setup Complete!" -ForegroundColor Green Write-Host "You can now use the following commands:" Write-Host " stf-download --help" Write-Host " stf-split --help" +Write-Host " stf-build-vocab --help" Write-Host "" Write-Host "To run tests, use: .\test.ps1" diff --git a/setup.sh b/setup.sh index 4a18e26..a16c23e 100644 --- a/setup.sh +++ b/setup.sh @@ -7,8 +7,8 @@ ENV_NAME="SpatialTranscriptFormer" # Check if conda environment exists if ! conda env list | grep -q "$ENV_NAME"; then - echo "Creating conda environment '$ENV_NAME' with Python 3.10..." - conda create -n $ENV_NAME python=3.10 -y + echo "Creating conda environment '$ENV_NAME' with Python 3.9..." + conda create -n $ENV_NAME python=3.9 -y else echo "Conda environment '$ENV_NAME' already exists." fi @@ -20,7 +20,7 @@ echo "" echo "Setup Complete!" echo "You can now use the following commands (after activating the environment):" echo " stf-download --help" -echo " stf-download-bowel" echo " stf-split --help" +echo " stf-build-vocab --help" echo "" echo "To run tests, use: ./test.sh" diff --git a/src/spatial_transcript_former/models/interaction.py b/src/spatial_transcript_former/models/interaction.py index 34d50c4..1832f7d 100644 --- a/src/spatial_transcript_former/models/interaction.py +++ b/src/spatial_transcript_former/models/interaction.py @@ -2,14 +2,9 @@ Pathway-histology interaction layers for SpatialTranscriptFormer. This module defines the building blocks that fuse learnable pathway tokens -with histology patch features via self- and cross-attention: - -* ``PathwayTokenizer`` — learnable pathway embeddings -* ``LocalPatchMixer`` — scatter-gather depthwise conv for neighbourhood mixing -* ``SinusoidalPositionalEncoder`` — 2-D sinusoidal positional embeddings -* ``EarlyFusionBlock`` — concat-attention-slice wrapper (Jaume-mode) -* ``MultimodalFusion`` — standard Transformer early fusion -* ``NystromEncoder`` — Nyström-based early fusion +with histology patch features via self-attention: + +* ``LearnedSpatialEncoder`` — 2-D learned positional embeddings * ``SpatialTranscriptFormer`` — the full model """ @@ -19,449 +14,53 @@ from .backbones import get_backbone -class PathwayTokenizer(nn.Module): - """Projects pathway indices or data into latent embeddings. - - This module maintains a set of learnable embeddings of size dim for each - of the num_pathways defined. - - Attributes: - num_pathways (int): Number of pathway tokens to generate. - dim (int): Dimension of each pathway token. - """ - - def __init__(self, num_pathways, dim): - """Initializes the PathwayTokenizer. - - Args: - num_pathways (int): Total number of pathways. - dim (int): Token embedding dimension. - """ - super().__init__() - self.num_pathways = num_pathways - self.dim = dim - self.pathway_embeddings = nn.Parameter(torch.randn(1, num_pathways, dim)) - - def forward(self, batch_size): - """Generates pathway tokens for the batch. - - Args: - batch_size (int): Number of samples in the batch. - - Returns: - torch.Tensor: Pathway tokens of shape (batch_size, num_pathways, dim). - """ - return self.pathway_embeddings.expand(batch_size, -1, -1) - - -class LocalPatchMixer(nn.Module): - """Mixes each patch's features with its immediate spatial neighbours via a depthwise conv. - - The forward pass performs three steps: +class LearnedSpatialEncoder(nn.Module): + """Encodes 2D spatial coordinates via a small learned MLP. - 1. **Scatter** — place each patch feature vector into a 2-D dense grid at its - grid-coordinate position. Coordinates are zero-based and must be - integer-like (grid indices, *not* raw pixel coordinates). - 2. **Depthwise Conv + GELU** — apply a ``kernel_size x kernel_size`` grouped - convolution over the spatial grid. Each channel is processed independently - so the parameter count scales with *dim*, not *dim²*. - 3. **Gather + Residual** — read the convolved values back at the same grid - positions and add them to the original features (residual connection). - - A safety guard skips mixing if the bounding-box of grid coordinates exceeds - 256×256 cells, which would indicate malformed / non-grid coordinates and - could allocate prohibitive memory. + Unlike sinusoidal PE, this produces smooth, non-periodic embeddings + that vary gradually across the tissue. Coordinates are normalised to + [-1, 1] per-batch before encoding for scale invariance. """ - def __init__(self, dim, kernel_size=3): + def __init__(self, dim): super().__init__() - self.dim = dim - self.conv = nn.Conv2d( - dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim + self.mlp = nn.Sequential( + nn.Linear(2, dim), + nn.GELU(), + nn.Linear(dim, dim), ) - self.act = nn.GELU() - - def forward(self, x, coords): - """ - Args: - x: (B, N, D) Patch features - coords: (B, N, 2) Coordinates (absolute-ish indices) - Returns: - (B, N, D) Enriched features - """ - B, N, D = x.shape - device = x.device - - # 1. Normalize coords to grid indices - # We assume coords are already roughly grid-like or pixel coords. - # But to be safe, we subtract min and assume unit step? - # Actually, let's assume coords ARE grid indices for now, or scaled pixels. - # If they are large pixels (e.g. 10000), we need to know the patch size (224). - # We'll trust the caller to pass grid-indices. - - # Find bounds per batch? Or global max? - # To batch this efficiently, we find global bounds in the batch - min_c = coords.min(dim=1, keepdim=True)[0] # (B, 1, 2) - grid_coords = coords - min_c # Zero-based - - max_c = grid_coords.max(dim=1)[0].max(dim=0)[0] # (2,) - H, W = int(max_c[1].item()) + 1, int(max_c[0].item()) + 1 - - # Cap memory usage if outliers exist - if H * W > 256 * 256: - # Fallback: Don't crash on massive outlier. Just skip mix? - # Or clamp? Let's skip mixing if grid is absurd. - return x - - # 2. Scatter - grid = torch.zeros(B, D, H, W, device=device, dtype=x.dtype) - - b_idx = torch.arange(B, device=device).view(B, 1, 1) - d_idx = torch.arange(D, device=device).view(1, D, 1) - - gx = grid_coords[..., 0].long() # (B, N) - gy = grid_coords[..., 1].long() # (B, N) - - gy_idx = gy.unsqueeze(1) # (B, 1, N) - gx_idx = gx.unsqueeze(1) # (B, 1, N) - - x_T = x.transpose(1, 2) # (B, D, N) - grid[b_idx, d_idx, gy_idx, gx_idx] = x_T - - # 3. Conv - out_grid = self.act(self.conv(grid)) - - # 4. Gather & Residual - out = out_grid[b_idx, d_idx, gy_idx, gx_idx].transpose( - 1, 2 - ) # (B, D, N) -> (B, N, D) - - return x + out - - -class GraphPatchMixer(nn.Module): - """Mixes each patch's features with its k-nearest physical neighbors using Graph Attention. - - This acts as a spatial refiner. It constructs a k-NN graph on the fly from the - physical coordinates and performs message passing. - """ - - def __init__(self, dim, k=8, heads=4): - super().__init__() - self.dim = dim - self.k = k - self.heads = heads - self.head_dim = dim // heads - - assert self.head_dim * heads == dim, "dim must be divisible by heads" - - # GAT-style linear projections - self.to_qkv = nn.Linear(dim, dim * 3, bias=False) - self.proj = nn.Linear(dim, dim) - self.act = nn.GELU() - - def forward(self, x, coords, mask=None): - """ - Args: - x: (B, N, D) Patch features - coords: (B, N, 2) Coordinates (absolute physical positions) - mask: (B, N) Boolean padding mask (True = padding) - - Returns: - (B, N, D) Refined features - """ - B, N, D = x.shape - device = x.device - - # 1. Build k-NN graph - # Compute pairwise distances (B, N, N) - dist = torch.cdist(coords, coords) - - k_actual = min(self.k + 1, N) # +1 for self-loop, bounded by N - dist_for_topk = dist.clone() - if mask is not None: - dist_for_topk.masked_fill_(mask.unsqueeze(1).expand(B, N, N), float("inf")) - - _, nn_idx = torch.topk(-dist_for_topk, k=k_actual, dim=-1) # (B, N, K) - - # 2. Extract neighbor features - batch_indices = torch.arange(B, device=device).view(-1, 1, 1) # (B, 1, 1) - x_neighbors = x[batch_indices, nn_idx, :] # (B, N, K, D) - - # 3. Message Passing (GAT-style) - # Query comes from the center node, Key/Value come from neighbors - qkv_center = self.to_qkv(x) # (B, N, 3D) - q_center, _, _ = qkv_center.chunk(3, dim=-1) # (B, N, D) - - qkv_neighbors = self.to_qkv(x_neighbors) - _, k_neighbors, v_neighbors = qkv_neighbors.chunk(3, dim=-1) # (B, N, K, D) - - # Reshape for multi-head attention - q = q_center.view(B, N, self.heads, self.head_dim) # (B, N, H, d) - k = k_neighbors.view( - B, N, k_actual, self.heads, self.head_dim - ) # (B, N, K, H, d) - v = v_neighbors.view( - B, N, k_actual, self.heads, self.head_dim - ) # (B, N, K, H, d) - - # Attention scores: Q * K^T - q = q.unsqueeze(3) # (B, N, H, 1, d) - k = k.permute(0, 1, 3, 2, 4) # (B, N, H, K, d) - - # dot product over 'd' - attn = (q * k).sum(dim=-1) / (self.head_dim**0.5) # (B, N, H, K) - - if mask is not None: - batch_indices = torch.arange(B, device=device).view(-1, 1, 1) # (B, 1, 1) - neighbor_is_padded = mask[batch_indices, nn_idx] # (B, N, K) - neighbor_is_padded = neighbor_is_padded.unsqueeze(2).expand( - -1, -1, self.heads, -1 - ) - attn = attn.masked_fill(neighbor_is_padded, float("-inf")) - - attn = F.softmax(attn, dim=-1) # (B, N, H, K) - if mask is not None: - attn = torch.nan_to_num(attn, nan=0.0) - - # Weighted sum of values - v = v.permute(0, 1, 3, 2, 4) # (B, N, H, K, d) - out = (attn.unsqueeze(-1) * v).sum(dim=-2) # (B, N, H, d) - - # Reshape back to D - out = out.reshape(B, N, D) - out = self.proj(self.act(out)) - - # 4. Residual Connection - return x + out - - -class SinusoidalPositionalEncoder(nn.Module): - """Encodes 2D spatial coordinates into sinusoidal embeddings. - - - Based on the "Attention is All You Need" 1D PE, extended to 2D. - Each spatial dimension (x, y) is encoded separately with dim/2 channels. - """ - - def __init__(self, dim, temperature=10000): - super().__init__() - self.dim = dim - self.temperature = temperature + def _normalize_coords(self, coords): + """Normalize coordinates to [-1, 1] range per batch.""" + # Centre at zero + coords = coords - coords.mean(dim=1, keepdim=True) + # Scale to [-1, 1] + scale = coords.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1.0) + return coords / scale def forward(self, rel_coords): """ Args: - rel_coords (torch.Tensor): (B, S, 2) relative coordinates. + rel_coords (torch.Tensor): (B, S, 2) spatial coordinates. Returns: torch.Tensor: (B, S, dim) positional embeddings. """ - x = rel_coords[..., 0] - y = rel_coords[..., 1] - - # Split dim into two halves for x and y - dim_x = self.dim // 2 - dim_y = self.dim - dim_x - - # Create geometric progression of frequencies - # div_term = 1 / (temperature ** (2i / d_model)) - div_term_x = torch.exp( - torch.arange(0, dim_x, 2, dtype=torch.float32, device=rel_coords.device) - * -(torch.log(torch.tensor(self.temperature)) / dim_x) - ) - div_term_y = torch.exp( - torch.arange(0, dim_y, 2, dtype=torch.float32, device=rel_coords.device) - * -(torch.log(torch.tensor(self.temperature)) / dim_y) - ) - - pe_x = torch.zeros(*x.shape, dim_x, device=rel_coords.device) - pe_y = torch.zeros(*y.shape, dim_y, device=rel_coords.device) - - # Sin/Cos pairs for X - pe_x[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term_x) - pe_x[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term_x) - - # Sin/Cos pairs for Y - pe_y[..., 0::2] = torch.sin(y.unsqueeze(-1) * div_term_y) - pe_y[..., 1::2] = torch.cos(y.unsqueeze(-1) * div_term_y) - - # Concatenate X and Y embeddings -> (B, S, dim) - return torch.cat([pe_x, pe_y], dim=-1) - - -class EarlyFusionBlock(nn.Module): - """Generic wrapper for Early Fusion interaction. - - Handles the common logic for "concatenation -> attention -> slicing" used in - multimodal interaction (Jaume et al. mode). - """ - - def __init__(self, encoder, masked_quadrants=None): - """Initializes EarlyFusionBlock. - - Args: - encoder (nn.Module): The attention mechanism (Standard or Nystrom). - masked_quadrants (list, optional): Quadrants to mask. - """ - super().__init__() - self.encoder = encoder - self.masked_quadrants = masked_quadrants or [] - - def _generate_quadrant_mask(self, num_p, num_h, device): - """Generates Eq 1 quadrant mask to restrict attention between modalities. - - [ P2P P2H ] - [ H2P H2H ] - """ - n_total = num_p + num_h - mask = torch.zeros((n_total, n_total), device=device, dtype=torch.bool) - - p_slice = slice(0, num_p) - h_slice = slice(num_p, n_total) - - if "H2H" in self.masked_quadrants: - mask[h_slice, h_slice] = True - if "H2P" in self.masked_quadrants: - mask[h_slice, p_slice] = True - if "P2H" in self.masked_quadrants: - mask[p_slice, h_slice] = True - - return mask - - def forward(self, p_tokens, h_tokens, return_all_tokens=False): - """Performs multimodal fusion. - - Args: - p_tokens: Pathway tokens (B, Np, D) - h_tokens: Histology tokens (B, Nh, D) - return_all_tokens: If True, returns full concatenated sequence. - - Returns: - Contextualized tokens. - """ - np = p_tokens.shape[1] - nh = h_tokens.shape[1] - - x = torch.cat([p_tokens, h_tokens], dim=1) - - mask = None - if self.masked_quadrants: - mask = self._generate_quadrant_mask(np, nh, p_tokens.device) - - # Standard PyTorch MHA and NystromAttention use different mask signatures. - # We handle the dispatch here to ensure correct mask application. - out = self.encoder(x, mask=mask) - - if return_all_tokens: - return out - - return out[:, :np, :] - - -class MultimodalFusion(EarlyFusionBlock): - """Standard Transformer-based Early Fusion.""" - - def __init__(self, dim, n_heads, n_layers, dropout=0.1, masked_quadrants=None): - encoder_layer = nn.TransformerEncoderLayer( - d_model=dim, - nhead=n_heads, - dim_feedforward=dim * 4, - dropout=dropout, - batch_first=True, - ) - transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) - super().__init__(encoder=transformer, masked_quadrants=masked_quadrants) + return self.mlp(self._normalize_coords(rel_coords)) -class NystromStack(nn.Module): - """Sequential stack of Nyström attention layers, compatible with ``EarlyFusionBlock``. - - Each element of *layers* is a ``nn.ModuleList`` of - ``(norm1, NystromAttention, norm2, FeedForward)`` sub-modules following a - pre-norm residual layout. The attention mask convention differs from - standard PyTorch MHA: ``True`` means *keep* (attend), so the mask is - inverted (``~mask``) before being passed to ``NystromAttention``. - """ - - def __init__(self, layers): - super().__init__() - self.layers = nn.ModuleList(layers) - - def forward(self, x, mask=None): - attn_mask = ~mask if mask is not None else None - - for norm1, attn, norm2, ff in self.layers: - x = x + attn(norm1(x), mask=attn_mask) - x = x + ff(norm2(x)) - return x - - -class NystromEncoder(EarlyFusionBlock): - """Nystrom-based Early Fusion.""" - - def __init__( - self, - dim, - n_heads, - n_layers, - dropout=0.1, - num_landmarks=256, - masked_quadrants=None, - ): - from nystrom_attention import NystromAttention - - layers = [] - for _ in range(n_layers): - layers.append( - nn.ModuleList( - [ - nn.LayerNorm(dim), - NystromAttention( - dim=dim, - heads=n_heads, - dim_head=dim // n_heads, - num_landmarks=num_landmarks, - dropout=dropout, - ), - nn.LayerNorm(dim), - nn.Sequential( - nn.Linear(dim, dim * 4), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(dim * 4, dim), - nn.Dropout(dropout), - ), - ] - ) - ) - - stack = NystromStack(layers) - super().__init__(encoder=stack, masked_quadrants=masked_quadrants) - - -# Removed NystromDecoderLayer and NystromDecoder as we are focusing solely on the Jaume pipeline. +VALID_INTERACTIONS = {"p2p", "p2h", "h2p", "h2h"} class SpatialTranscriptFormer(nn.Module): """Transformer for predicting gene expression from histology and spatial context. Integrates histology feature extraction with pathway-based bottleneck - attention to predict gene transcript counts. Supports standard decoder-style - interaction or early fusion multimodal interaction. - - Architectural Inspiration: - Jaume et al. (2024). "Modeling Dense Multimodal Interactions Between - Biological Pathways and Histology for Survival Prediction" (SurvPath). - - Benchmark/Framework: - Jaume et al. (2024). "HEST-1k: A Dataset for Spatial Transcriptomics - and Histology Image Analysis." NeurIPS (Spotlight). + attention to predict gene transcript counts. Follows a standard Vision + Transformer architecture where pathway tokens act as [CLS]-like bottlenecks. Attributes: num_pathways (int): Number of pathway bottlenecks. - use_nystrom (bool): Whether to use efficient Nystrom attention. """ def __init__( @@ -470,19 +69,14 @@ def __init__( num_pathways=50, backbone_name="resnet50", pretrained=True, - token_dim=512, - n_heads=8, + token_dim=256, + n_heads=4, n_layers=2, dropout=0.1, - use_nystrom=False, - mask_radius=None, - masked_quadrants=None, - num_landmarks=256, pathway_init=None, use_spatial_pe=True, - early_mixer="conv", - late_refiner=None, - k_neighbors=8, + output_mode="counts", + interactions=None, ): """Initializes the SpatialTranscriptFormer. @@ -495,27 +89,42 @@ def __init__( n_heads (int): Number of attention heads. n_layers (int): Number of transformer/interaction layers. dropout (float): Dropout probability. - use_nystrom (bool): Enable linear-complexity attention. - mask_radius (float, optional): Distance-based masking threshold. - masked_quadrants (list, optional): Mask configuration for fusion. - num_landmarks (int): Landmarks for Nystrom attention. pathway_init (Tensor, optional): Biological pathway membership matrix of shape (P, G) to initialize gene_reconstructor. use_spatial_pe (bool): Incorporate relative gradients into attention. - early_mixer (str, optional): 'conv' or None. - late_refiner (str, optional): 'gnn' or None. - k_neighbors (int): k-NN for GNN refiner. + output_mode (str): 'counts' (standard MSE/PCC) or 'zinb' (Zero-Inflated Negative Binomial outputs). + interactions (list[str], optional): Which attention interactions to + enable. Valid keys are ``p2p``, ``p2h``, ``h2p``, ``h2h``. + Defaults to all four (full self-attention). """ super().__init__() + if interactions is None: + interactions = list(VALID_INTERACTIONS) + unknown = set(interactions) - VALID_INTERACTIONS + if unknown: + raise ValueError( + f"Unknown interaction keys: {unknown}. " + f"Valid keys are: {VALID_INTERACTIONS}" + ) + self.interactions = set(interactions) + + # Enforce minimum 2 layers when h2h is blocked. + # Layer 1 lets pathways gather patch info, Layer 2 lets patches + # read the now-informed pathway tokens. + if "h2h" not in self.interactions and n_layers < 2: + raise ValueError( + f"n_layers must be >= 2 when h2h is not in interactions. " + f"Got n_layers={n_layers}. Layer 1 lets pathways gather spatial info, " + f"Layer 2 lets patches read contextualized pathways." + ) + # Override num_pathways if biological init is provided if pathway_init is not None: num_pathways = pathway_init.shape[0] print(f"Pathway init: overriding num_pathways to {num_pathways}") self.num_pathways = num_pathways - self.use_nystrom = use_nystrom - self.mask_radius = mask_radius self.use_spatial_pe = use_spatial_pe # 1. Image Encoder Backbone @@ -523,50 +132,32 @@ def __init__( backbone_name, pretrained=pretrained ) - # 2. Image Projector & Spatial Modules + # 2. Image Projector self.image_proj = nn.Linear(self.image_feature_dim, token_dim) - self.early_mixer = None - if early_mixer == "conv": - self.early_mixer = LocalPatchMixer(token_dim) - - self.late_refiner = None - if late_refiner == "gnn": - self.late_refiner = GraphPatchMixer( - dim=token_dim, k=k_neighbors, heads=n_heads - ) - - self.pathway_tokenizer = PathwayTokenizer(num_pathways, token_dim) + # 3. Learnable pathway tokens (one per pathway, shared across batch) + self.pathway_tokens = nn.Parameter(torch.randn(1, num_pathways, token_dim)) - # 3b. Spatial Positional Encoder + # 4. Spatial Positional Encoder (optional) self.spatial_encoder = None if use_spatial_pe: - self.spatial_encoder = SinusoidalPositionalEncoder(token_dim) + self.spatial_encoder = LearnedSpatialEncoder(token_dim) - # 4. Interaction Engine (Unified Early Fusion Path) - if use_nystrom: - if masked_quadrants: - print( - "Warning: Nystrom attention does not support 2D quadrant masking. Mask will be ignored." - ) - self.fusion_engine = NystromEncoder( - token_dim, - n_heads, - n_layers, - dropout=dropout, - num_landmarks=num_landmarks, - masked_quadrants=None, - ) - else: - self.fusion_engine = MultimodalFusion( - token_dim, - n_heads, - n_layers, - dropout=dropout, - masked_quadrants=masked_quadrants, - ) + # 5. Interaction Engine (Standard Transformer) + encoder_layer = nn.TransformerEncoderLayer( + d_model=token_dim, + nhead=n_heads, + dim_feedforward=token_dim * 4, + dropout=dropout, + batch_first=True, + norm_first=True, + ) + self.fusion_engine = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) + + # Learnable temperature for cosine similarity scoring + # Initialized to log(1/0.07) ≈ 2.66 following CLIP convention + self.log_temperature = nn.Parameter(torch.tensor(2.6593)) - self.pathway_activator = nn.Linear(token_dim, 1) self.gene_reconstructor = nn.Linear(num_pathways, num_genes) if pathway_init is not None: @@ -575,215 +166,212 @@ def __init__( # pathway_init is (num_pathways, num_genes) self.gene_reconstructor.weight.copy_(pathway_init.T) self.gene_reconstructor.bias.zero_() + # Expose the MSigDB matrix for AuxiliaryPathwayLoss + self._pathway_init_matrix = pathway_init.clone() print("Initialized gene_reconstructor with MSigDB Hallmarks") + else: + self._pathway_init_matrix = None - # Dense Head (Whole Slide Mode) - self.gene_head = nn.Linear(token_dim, num_genes) + self.output_mode = output_mode + if self.output_mode == "zinb": + # pi: probability of dropout (zero-inflation) + self.pi_reconstructor = nn.Linear(num_pathways, num_genes) + # theta: inverse dispersion + self.theta_reconstructor = nn.Linear(num_pathways, num_genes) - def get_sparsity_loss(self): - """Computes L1 norm of reconstruction weights for sparsity regularization. + # Initialize ZINB specialized heads carefully to avoid immediate collapse + with torch.no_grad(): + # Initialize pi aggressively negative so sigmoid(pi) is near 0.05 + nn.init.normal_(self.pi_reconstructor.weight, std=0.01) + nn.init.constant_(self.pi_reconstructor.bias, -3.0) - Returns: - torch.Tensor: L1 loss value. - """ - return torch.norm(self.gene_reconstructor.weight, p=1) + # Initialize theta so softplus(theta) is roughly 1.0 + nn.init.normal_(self.theta_reconstructor.weight, std=0.01) + nn.init.constant_(self.theta_reconstructor.bias, 0.5) - def _normalize_coords(self, coords): - """ - Infer grid scaling from absolute coordinates. - (e.g., convert 224-pixel steps into unit steps). + def _build_interaction_mask(self, p, s, device): + """Build ``(P+S, P+S)`` boolean attention mask from ``self.interactions``. - Args: - coords: (B, N, 2) absolute pixel coordinates. + In PyTorch transformer convention, ``True`` means **blocked**. Returns: - (B, N, 2) grid-aligned indices. + torch.Tensor or None: Mask tensor, or ``None`` when all + interactions are enabled (no masking needed). """ - if coords is None: - return None - - # Zero-base - min_c = coords.min(dim=1, keepdim=True)[0] - zero_c = coords - min_c - - # Infer step per batch - # We use the median of non-zero adjacent differences to find step size - B, N, _ = zero_c.shape - if N < 2: - return torch.zeros_like(zero_c) - - steps = [] - for i in [0, 1]: # X and Y - c_sorted, _ = torch.sort(zero_c[..., i], dim=1) - diffs = c_sorted[:, 1:] - c_sorted[:, :-1] - diffs = diffs[diffs > 0] - step = diffs.median() if diffs.numel() > 0 else 1.0 - steps.append(step) - - grid_coords = torch.stack( - [ - torch.round(zero_c[..., 0] / steps[0]), - torch.round(zero_c[..., 1] / steps[1]), - ], - dim=-1, - ) - return grid_coords - - def _generate_spatial_mask(self, rel_coords): - """Generates distance-based masks for neighborhood attention. + if self.interactions >= VALID_INTERACTIONS: + return None # everything enabled, skip masking + + total = p + s + # Start fully blocked + mask = torch.ones(total, total, dtype=torch.bool, device=device) + + if "p2p" in self.interactions: + mask[:p, :p] = False + if "p2h" in self.interactions: + mask[:p, p:] = False + if "h2p" in self.interactions: + mask[p:, :p] = False + if "h2h" in self.interactions: + mask[p:, p:] = False + + # Always allow self-attention (diagonal) + mask.fill_diagonal_(False) + return mask - Args: - rel_coords (torch.Tensor): Relative coordinates (B, S, 2). + def get_sparsity_loss(self): + """Computes L1 norm of reconstruction weights for sparsity regularization. Returns: - torch.Tensor: Boolean mask (B, S) where True means ignore. + torch.Tensor: L1 loss value. """ - if self.mask_radius is None: - return None - - # Calculate Euclidean distances from center (0, 0) - dists = torch.norm(rel_coords, dim=-1) - - # Mask elements that are beyond the specified radius - mask = dists > self.mask_radius - return mask + return torch.norm(self.gene_reconstructor.weight, p=1) - def forward(self, x, rel_coords=None, return_pathways=False): + def forward( + self, + x, + rel_coords=None, + return_pathways=False, + mask=None, + return_dense=False, + return_attention=False, + ): """Main inference path. Args: x (torch.Tensor): Image data or pre-computed features. - (B, 3, H, W): Single image patch. - - (B, S, 3, H, W): Image neighborhood. - (B, S, D): Pre-computed features. rel_coords (torch.Tensor, optional): Spatial relative coordinates. return_pathways (bool): Whether to return pathway activations. + mask (torch.Tensor, optional): Boolean padding mask for patches (B, S) where True = Padding. + return_dense (bool): If True, returns per-patch gene predictions instead of global predictions. + return_attention (bool): If True, returns attention maps from all layers. Returns: - torch.Tensor: Predicted gene counts (B, num_genes). - (Optional) torch.Tensor: Pathway activations (B, num_pathways). + torch.Tensor: Predicted gene counts (B, num_genes) or (B, N, num_genes) if return_dense. + (Optional) torch.Tensor: Pathway activations/scores. + (Optional) list[torch.Tensor]: Attention maps [L, B, H, T, T] if return_attention. """ - if x.dim() == 5: - # Neighborhood Mode: Extract features for all patches - b, s, c, h, w = x.shape - x_flat = x.view(b * s, c, h, w) - features_flat = self.backbone(x_flat) - features = features_flat.view(b, s, -1) - elif x.dim() == 4: - # Single Patch Mode + if x.dim() == 4: + # Single Patch Mode: (B, C, H, W) features = self.backbone(x).unsqueeze(1) b, s = features.shape[0], 1 else: - # Assumed pre-computed or pre-reshaped features (B, S, D) + # Pre-computed features: (B, S, D) features = x - if features.dim() == 2: - features = features.unsqueeze(1) b, s = features.shape[0], features.shape[1] # 1. Project features into latent interaction space memory = self.image_proj(features) # 1b. Inject Spatial Positional Encodings - if self.use_spatial_pe and rel_coords is not None: - # Add spatial PE to visual features + if self.use_spatial_pe: + if rel_coords is None: + raise ValueError( + "use_spatial_pe is True, but rel_coords was not provided. " + "Ensure the dataloader passes spatial coordinates." + ) pe = self.spatial_encoder(rel_coords) memory = memory + pe - # 1c. Local Patch Mixing (Conv) - if ( - hasattr(self, "early_mixer") - and self.early_mixer is not None - and rel_coords is not None - ): - # Enforce mixing of histology features before they attend to pathways - # Ensure coords are normalized to grid indices for LocalPatchMixer - grid_coords = self._normalize_coords(rel_coords) - memory = self.early_mixer(memory, grid_coords) - # 2. Retrieve learnable pathway tokens - tgt = self.pathway_tokenizer(b) + pathway_tokens = self.pathway_tokens.expand(b, -1, -1) # (B, P, D) + p = pathway_tokens.shape[1] - # 3. Process Interactions (Unified Early Fusion Path) - out = self.fusion_engine(p_tokens=tgt, h_tokens=memory) - - # 4. Project focused pathway tokens back to gene space - # We project each pathway token to a scalar activation, then use these - # activations to reconstruct gene expression levels (the bottleneck). - pathway_activations = self.pathway_activator(out).squeeze( - -1 - ) # (B, num_pathways) - gene_expression = self.gene_reconstructor(pathway_activations) # (B, num_genes) - - if return_pathways: - return gene_expression, pathway_activations - return gene_expression + # 3. Process Interactions (Standard ViT sequence: [Pathways, Patches]) + sequence = torch.cat([pathway_tokens, memory], dim=1) # (B, P + S, D) - def forward_dense(self, x, mask=None, return_pathways=False, coords=None): - """Predicts individualized gene counts for every patch in a slide. - - Optimized for dense prediction on whole slide images via global context. - - Args: - x (torch.Tensor): Precomputed features for N patches (B, N, D). - mask (torch.Tensor, optional): Boolean padding mask (B, N) where True = Padding. - return_pathways (bool): Whether to return pathway scores. - coords (torch.Tensor, optional): Absolute coordinates (B, N, 2) for PE. - - Returns: - torch.Tensor: Contextualized predictions for each patch (B, N, G). - (Optional) torch.Tensor: Pathway scores (B, N, P). - """ - if x.dim() == 2: - x = x.unsqueeze(0) + # Build attention mask from configured interactions + interaction_mask = self._build_interaction_mask(p, s, sequence.device) - b = x.shape[0] # Dynamic batch size + # If sparse/padded inputs, mask out padding so it doesn't attend + pad_mask = None + if mask is not None: + # Pad pathway tokens with False (don't ignore) + pad_mask = torch.cat( + [torch.zeros(b, p, dtype=torch.bool, device=mask.device), mask], dim=1 + ) - # 1. Project to latent space - memory = self.image_proj(x) + attentions = [] + if return_attention: + # Manual forward through fusion_engine layers to extract weights + # Standard nn.TransformerEncoder suppresses weights for performance. + x_layer = sequence + for layer in self.fusion_engine.layers: + # Multi-head attention bit + # We need to call the internal self_attn with need_weights=True + attn_output, attn_weights = layer.self_attn( + x_layer, + x_layer, + x_layer, + attn_mask=interaction_mask, + key_padding_mask=pad_mask, + need_weights=True, + ) + attentions.append(attn_weights) + + # Rest of the layer (as per nn.TransformerEncoderLayer) + if layer.norm_first: + x_layer = x_layer + layer._sa_block( + layer.norm1(x_layer), interaction_mask, pad_mask + ) + x_layer = x_layer + layer._ff_block(layer.norm2(x_layer)) + else: + x_layer = layer.norm1( + x_layer + layer._sa_block(x_layer, interaction_mask, pad_mask) + ) + x_layer = layer.norm2(x_layer + layer._ff_block(x_layer)) + out = x_layer + else: + out = self.fusion_engine( + sequence, mask=interaction_mask, src_key_padding_mask=pad_mask + ) - # 1b. Inject Spatial Positional Encodings (Global) - if self.use_spatial_pe and coords is not None: - pe = self.spatial_encoder(coords) - memory = memory + pe + # Extract focused pathway tokens + processed_pathway_tokens = out[:, :p, :] # (B, P, D) - # 1c. Local Patch Mixing (Conv) - if ( - hasattr(self, "early_mixer") - and self.early_mixer is not None - and coords is not None - ): - grid_coords = self._normalize_coords(coords) - memory = self.early_mixer(memory, grid_coords) - - # 2. Retrieve learnable pathway tokens (global context) - pathway_tokens = self.pathway_tokenizer(b) # (B, P, D) - - # 3. Perform dense interaction (Global Patch-to-Patch + Pathway context) - # The MultimodalFusion/NystromEncoder class handles the concatenation and quadrant masking. - all_tokens = self.fusion_engine( - p_tokens=pathway_tokens, h_tokens=memory, return_all_tokens=True - ) + # Extract processed patch tokens + processed_patch_tokens = out[:, p:, :] # (B, S, D) - # Sliced Histology tokens: indices [Np:] - np = pathway_tokens.shape[1] - context_features = all_tokens[:, np:, :] + # 5. Compute pathway scores via cosine similarity with learnable temperature + # L2-normalize both sets of tokens to produce cosine similarities in [-1, 1] + norm_pathway = F.normalize(processed_pathway_tokens, dim=-1) # (B, P, D) + temperature = self.log_temperature.exp() # scalar - # 3b. Local GNN Refinement - if self.late_refiner is not None and coords is not None: - # Explicit skip connection: Inject raw spatial/visual memory back into contextualized tokens - context_features = self.late_refiner( - context_features + memory, coords, mask=mask + if return_dense: + # Dense prediction: per-patch cosine similarity with pathway tokens + norm_patch = F.normalize(processed_patch_tokens, dim=-1) # (B, S, D) + # (B, S, D) @ (B, D, P) -> (B, S, P) + pathway_scores = ( + torch.matmul(norm_patch, norm_pathway.transpose(1, 2)) * temperature ) + else: + # Global prediction: pool patches first, then compute scores + global_patch_token = processed_patch_tokens.mean( + dim=1, keepdim=True + ) # (B, 1, D) + norm_global = F.normalize(global_patch_token, dim=-1) # (B, 1, D) + pathway_scores = ( + torch.matmul(norm_global, norm_pathway.transpose(1, 2)) * temperature + ) + pathway_scores = pathway_scores.squeeze(1) # (B, P) + + # Gene reconstruction (unified for both modes) + if self.output_mode == "zinb": + mu = F.softplus(self.gene_reconstructor(pathway_scores)) + 1e-6 + mu = torch.clamp(mu, max=1e5) + pi = torch.sigmoid(self.pi_reconstructor(pathway_scores)) + theta = F.softplus(self.theta_reconstructor(pathway_scores)) + 1e-6 + gene_expression = (pi, mu, theta) + else: + gene_expression = self.gene_reconstructor(pathway_scores) - # 4. Dense prediction head (Pathway Bottleneck enforcement) - # Project updated patch tokens onto pathway embeddings: (B, N, D) @ (B, D, P) -> (B, N, P) - # We scale by 1/sqrt(D) to maintain reasonable activation variance, as this is effectively attention. - pathway_scores = torch.matmul( - context_features, pathway_tokens.transpose(1, 2) - ) / (context_features.shape[-1] ** 0.5) - - # Reconstruct Genes: (B, N, P) @ (P, G) -> (B, N, G) - # Reuses the gene_reconstructor for biological consistency. + results = [gene_expression] if return_pathways: - return self.gene_reconstructor(pathway_scores), pathway_scores - return self.gene_reconstructor(pathway_scores) + results.append(pathway_scores) + if return_attention: + results.append(attentions) + + if len(results) == 1: + return results[0] + return tuple(results) diff --git a/src/spatial_transcript_former/predict.py b/src/spatial_transcript_former/predict.py index 1bfb65c..f8783cf 100644 --- a/src/spatial_transcript_former/predict.py +++ b/src/spatial_transcript_former/predict.py @@ -253,17 +253,18 @@ def plot_training_summary( if len(label) > 30: label = label[:27] + "..." - truth_raw = pathway_truth[:, pw_idx] - pred_raw = pathway_pred[:, pw_idx] + truth_vals = pathway_truth[:, pw_idx] + pred_vals = pathway_pred[:, pw_idx] - # Use absolute bounds instead of per-plot z-scored normalized bounds - vmin = min(truth_raw.min(), pred_raw.min()) - vmax = max(truth_raw.max(), pred_raw.max()) + # Both truth and pred are now in the same units (mean log1p expression + # of pathway member genes), so shared bounds give a fair comparison + vmin = min(truth_vals.min(), pred_vals.min()) + vmax = max(truth_vals.max(), pred_vals.max()) sc = None for col, vals, suffix in [ - (col_gt, truth_raw, "Truth"), - (col_pred, pred_raw, "Pred"), + (col_gt, truth_vals, "Truth"), + (col_pred, pred_vals, "Pred"), ]: ax = fig.add_subplot(inner[row, col]) ax.set_facecolor("#0d0d1a") @@ -331,11 +332,9 @@ def main(): parser.add_argument( "--n-neighbors", type=int, default=0, help="Number of spatial neighbors to use" ) - parser.add_argument( - "--use-nystrom", - action="store_true", - help="Use Nystrom attention for linear complexity", - ) + parser.add_argument("--token-dim", type=int, default=256) + parser.add_argument("--n-heads", type=int, default=4) + parser.add_argument("--n-layers", type=int, default=2) parser.add_argument( "--num-pathways", type=int, @@ -351,6 +350,13 @@ def main(): parser.add_argument( "--plot-pathways", action="store_true", help="Visualize pathway activations" ) + parser.add_argument( + "--loss", + type=str, + default="mse", + choices=["mse", "pcc", "mse_pcc", "zinb", "poisson", "logcosh"], + help="Loss function used for training (needed for model reconstruction)", + ) args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -372,9 +378,12 @@ def main(): elif args.model_type == "interaction": model = SpatialTranscriptFormer( num_genes=args.num_genes, - use_nystrom=args.use_nystrom, + token_dim=args.token_dim, + n_heads=args.n_heads, + n_layers=args.n_layers, num_pathways=args.num_pathways, backbone_name=args.backbone, + output_mode="zinb" if args.loss == "zinb" else "counts", ) state_dict = torch.load(args.model_path, map_location=device, weights_only=True) @@ -460,7 +469,7 @@ def main(): output = model( images, rel_coords=rel_coords, return_pathways=args.plot_pathways ) - if isinstance(output, tuple): + if isinstance(output, tuple) and args.plot_pathways: preds = output[0] pathways = output[1] all_pathways.append(pathways.cpu().numpy()) @@ -469,6 +478,10 @@ def main(): else: preds = model(images) + # Unpack ZINB tuple if generated + if isinstance(preds, tuple): + preds = preds[1] # Use mean component + all_preds.append(preds.cpu().numpy()) all_truth.append(targets.numpy()) diff --git a/src/spatial_transcript_former/train.py b/src/spatial_transcript_former/train.py index 3ad8783..b24c4f5 100644 --- a/src/spatial_transcript_former/train.py +++ b/src/spatial_transcript_former/train.py @@ -19,6 +19,7 @@ PCCLoss, CompositeLoss, MaskedMSELoss, + ZINBLoss, ) from spatial_transcript_former.training.engine import train_one_epoch, validate from spatial_transcript_former.training.experiment_logger import ExperimentLogger @@ -88,18 +89,14 @@ def setup_model(args, device): num_genes=args.num_genes, backbone_name=args.backbone, pretrained=args.pretrained, - use_nystrom=args.use_nystrom, - mask_radius=args.mask_radius, - masked_quadrants=args.masked_quadrants, + token_dim=args.token_dim, + n_heads=args.n_heads, + n_layers=args.n_layers, num_pathways=args.num_pathways, pathway_init=pathway_init, use_spatial_pe=args.use_spatial_pe, - early_mixer=( - None if args.early_mixer.lower() == "none" else args.early_mixer - ), - late_refiner=( - None if args.late_refiner.lower() == "none" else args.late_refiner - ), + output_mode="zinb" if args.loss == "zinb" else "counts", + interactions=getattr(args, "interactions", None), ) elif args.model == "attention_mil": from spatial_transcript_former.models.mil import AttentionMIL @@ -133,19 +130,34 @@ def setup_model(args, device): return model -def setup_criterion(args): - """Create loss function from CLI args.""" +def setup_criterion(args, pathway_init=None): + """Create loss function from CLI args. + + If ``pathway_init`` is provided and ``--pathway-loss-weight > 0``, + wraps the base criterion with :class:`AuxiliaryPathwayLoss`. + """ if args.loss == "pcc": - return PCCLoss() + base = PCCLoss() elif args.loss == "mse_pcc": - return CompositeLoss(alpha=args.pcc_weight) + base = CompositeLoss(alpha=args.pcc_weight) + elif args.loss == "zinb": + base = ZINBLoss() elif args.loss == "poisson": - return nn.PoissonNLLLoss(log_input=True) + base = nn.PoissonNLLLoss(log_input=True) elif args.loss == "logcosh": print("Using HuberLoss as proxy for LogCosh") - return nn.HuberLoss() + base = nn.HuberLoss() else: - return MaskedMSELoss() + base = MaskedMSELoss() + + pw_weight = getattr(args, "pathway_loss_weight", 0.0) + if pathway_init is not None and pw_weight > 0: + from spatial_transcript_former.training.losses import AuxiliaryPathwayLoss + + print(f"Wrapping criterion with AuxiliaryPathwayLoss (lambda={pw_weight})") + return AuxiliaryPathwayLoss(pathway_init, base, lambda_pathway=pw_weight) + + return base # --------------------------------------------------------------------------- @@ -248,7 +260,7 @@ def parse_args(): "--loss", type=str, default="mse", - choices=["mse", "pcc", "mse_pcc", "poisson", "logcosh"], + choices=["mse", "pcc", "mse_pcc", "zinb", "poisson", "logcosh"], ) parser.add_argument( "--pcc-weight", @@ -256,6 +268,12 @@ def parse_args(): default=1.0, help="Weight for PCC term in mse_pcc loss", ) + parser.add_argument( + "--pathway-loss-weight", + type=float, + default=0.0, + help="Weight for auxiliary pathway PCC loss (0 = disabled)", + ) # Model g = parser.add_argument_group("Model") @@ -269,26 +287,21 @@ def parse_args(): g.add_argument("--no-pretrained", action="store_false", dest="pretrained") g.set_defaults(pretrained=True) g.add_argument("--num-pathways", type=int, default=50) - g.add_argument("--use-nystrom", action="store_true") - g.add_argument("--mask-radius", type=float, default=None) + g.add_argument("--token-dim", type=int, default=256) + g.add_argument("--n-heads", type=int, default=4) + g.add_argument("--n-layers", type=int, default=2) g.add_argument( "--no-spatial-pe", action="store_false", dest="use_spatial_pe", - help="Disable Spatial Positional Encoding", - ) - g.set_defaults(use_spatial_pe=True) - g.add_argument( - "--early-mixer", - type=str, - default="conv", - help="Early spatial mixer ('conv' or 'none')", + help="Disable spatial positional encoding", ) + g.set_defaults(use_spatial_pe=False) g.add_argument( - "--late-refiner", - type=str, - default="none", - help="Late spatial refiner ('gnn' or 'none')", + "--interactions", + nargs="+", + default=None, + help="Attention interactions to enable: p2p, p2h, h2p, h2h (default: all)", ) # Training @@ -302,6 +315,7 @@ def parse_args(): "--lr", type=float, default=get_config("training.learning_rate", 1e-4) ) g.add_argument("--weight-decay", type=float, default=0.0) + g.add_argument("--warmup-epochs", type=int, default=10) g.add_argument("--sparsity-lambda", type=float, default=0.0) g.add_argument("--augment", action="store_true") g.add_argument("--use-amp", action="store_true") @@ -319,7 +333,6 @@ def parse_args(): g.add_argument("--use-global-context", action="store_true") g.add_argument("--global-context-size", type=int, default=128) g.add_argument("--compile-backend", type=str, default="inductor") - g.add_argument("--masked-quadrants", type=str, nargs="+", default=["H2H"]) g.add_argument("--plot-pathways", action="store_true") g.add_argument( "--weak-supervision", action="store_true", help="Bag-level training for MIL" @@ -373,12 +386,29 @@ def main(): # 2. Model, Loss, Optimizer model = setup_model(args, device) - criterion = setup_criterion(args) + # Pass pathway_init to criterion so AuxiliaryPathwayLoss can use it + pathway_init = getattr(model, "_pathway_init_matrix", None) + criterion = setup_criterion(args, pathway_init=pathway_init).to(device) optimizer = optim.Adam( model.parameters(), lr=args.lr, weight_decay=args.weight_decay ) + + # LR scheduler: cosine annealing with optional linear warmup + warmup_epochs = args.warmup_epochs + cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=args.epochs - warmup_epochs, eta_min=1e-6 + ) + + def lr_lambda(epoch): + if epoch < warmup_epochs: + return epoch / max(warmup_epochs, 1) + return 1.0 # cosine scheduler handles the rest + + warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + scaler = torch.amp.GradScaler("cuda") if args.use_amp else None print(f"Loss: {criterion.__class__.__name__}") + print(f"LR schedule: {warmup_epochs}-epoch warmup → cosine annealing to 1e-6") # 3. Output & Logger os.makedirs(args.output_dir, exist_ok=True) @@ -418,14 +448,28 @@ def main(): ) val_loss = val_metrics["val_loss"] - print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}") + print( + f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.2e}" + ) + + # Step LR scheduler + if epoch < warmup_epochs: + warmup_scheduler.step() + else: + cosine_scheduler.step() # Log epoch - epoch_row = {"train_loss": train_loss, "val_loss": val_loss} + epoch_row = { + "train_loss": train_loss, + "val_loss": val_loss, + "lr": optimizer.param_groups[0]["lr"], + } if val_metrics.get("val_mae") is not None: epoch_row["val_mae"] = round(val_metrics["val_mae"], 4) if val_metrics.get("val_pcc") is not None: epoch_row["val_pcc"] = round(val_metrics["val_pcc"], 4) + if val_metrics.get("pred_variance") is not None: + epoch_row["pred_variance"] = round(val_metrics["pred_variance"], 6) if val_metrics.get("attn_correlation") is not None: epoch_row["attn_correlation"] = round(val_metrics["attn_correlation"], 4) logger.log_epoch(epoch + 1, epoch_row) diff --git a/src/spatial_transcript_former/training/engine.py b/src/spatial_transcript_former/training/engine.py index ed26c79..7319b48 100644 --- a/src/spatial_transcript_former/training/engine.py +++ b/src/spatial_transcript_former/training/engine.py @@ -9,6 +9,7 @@ import torch.nn as nn from tqdm import tqdm from spatial_transcript_former.models import SpatialTranscriptFormer +from spatial_transcript_former.training.losses import AuxiliaryPathwayLoss def _optimizer_step( @@ -74,11 +75,29 @@ def train_one_epoch( ) with torch.amp.autocast("cuda", enabled=scaler is not None): - if hasattr(model, "forward_dense") and not getattr( + if isinstance(model, SpatialTranscriptFormer) and not getattr( model, "weak_supervision", False ): - preds = model.forward_dense(feats, mask=mask, coords=coords) - loss = criterion(preds, genes, mask=mask) + # Request pathway scores if criterion can use them + needs_pathways = isinstance(criterion, AuxiliaryPathwayLoss) + output = model( + feats, + return_dense=True, + mask=mask, + rel_coords=coords, + return_pathways=needs_pathways, + ) + if needs_pathways: + preds, pathway_preds = output + loss = criterion( + preds, + genes, + mask=mask, + pathway_preds=pathway_preds, + ) + else: + preds = output + loss = criterion(preds, genes, mask=mask) else: preds = model(feats) bag_target = _compute_bag_target(genes, mask) @@ -136,6 +155,7 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) running_loss = 0.0 running_mae = 0.0 pcc_list = [] + pred_var_list = [] attn_correlations = [] with torch.no_grad(): @@ -157,10 +177,21 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) attn = None if whole_slide: - if hasattr(model, "forward_dense") and not getattr( + if isinstance(model, SpatialTranscriptFormer) and not getattr( model, "weak_supervision", False ): - outputs = model.forward_dense(feats, mask=mask, coords=coords) + needs_pathways = isinstance(criterion, AuxiliaryPathwayLoss) + output = model( + feats, + return_dense=True, + mask=mask, + rel_coords=coords, + return_pathways=needs_pathways, + ) + if needs_pathways: + outputs, pathway_preds = output + else: + outputs = output targets = genes else: # MIL models: extract attention if supported @@ -179,33 +210,50 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) else: outputs = model(images) - loss = ( - criterion(outputs, targets, mask=mask) - if whole_slide - and hasattr(model, "forward_dense") + # Compute loss, passing pathway_preds if available + if ( + whole_slide + and isinstance(model, SpatialTranscriptFormer) and not getattr(model, "weak_supervision", False) - else criterion(outputs, targets) - ) + ): + if isinstance(criterion, AuxiliaryPathwayLoss): + loss = criterion( + outputs, + targets, + mask=mask, + pathway_preds=pathway_preds, + ) + else: + loss = criterion(outputs, targets, mask=mask) + else: + loss = criterion(outputs, targets) # --- Interpretability Metrics (MAE & PCC) --- - mae_diff = torch.abs(outputs - targets) + # Since ZINB loss outputs a tuple (pi, mu, theta), we only use mu (index 1) for evaluations against truth. + eval_preds = outputs[1] if isinstance(outputs, tuple) else outputs + mae_diff = torch.abs(eval_preds - targets) if ( whole_slide - and hasattr(model, "forward_dense") + and isinstance(model, SpatialTranscriptFormer) and not getattr(model, "weak_supervision", False) and mask is not None ): valid_mask = ~mask.unsqueeze(-1).expand_as(mae_diff) mae_val = (mae_diff * valid_mask.float()).sum() / valid_mask.sum() + else: + mae_val = mae_diff.mean() - if torch.isfinite(outputs).all() and torch.isfinite(targets).all(): + if ( + torch.isfinite(eval_preds).all() + and torch.isfinite(targets).all() + ): # Calculate Spatial PCC (across spots N, for each gene G independently) # outputs/targets are (B, N, G) for whole_slide or (B, G) for patch if whole_slide: # Iterate over batches to correlate spatially for each slide - B = outputs.shape[0] + B = eval_preds.shape[0] for b_idx in range(B): - p_slide = outputs[b_idx] # (N, G) + p_slide = eval_preds[b_idx] # (N, G) t_slide = targets[b_idx] # (N, G) valid_idx = ~mask[b_idx] @@ -231,7 +279,7 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) pcc_list.append(valid_corrs.mean().item()) else: # Patch level (B, G). Correlate across the batch B (which is spatial patches) - vx = outputs - outputs.mean(dim=0, keepdim=True) + vx = eval_preds - eval_preds.mean(dim=0, keepdim=True) vy = targets - targets.mean(dim=0, keepdim=True) num = torch.sum(vx * vy, dim=0) den = torch.sqrt( @@ -261,6 +309,23 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) running_loss += loss.item() running_mae += mae_val.item() + # Track prediction variance (collapse detector) + with torch.no_grad(): + if ( + whole_slide + and mask is not None + and isinstance(model, SpatialTranscriptFormer) + and not getattr(model, "weak_supervision", False) + ): + for b in range(eval_preds.shape[0]): + valid = ~mask[b] + if valid.sum() >= 2: + pred_var_list.append( + eval_preds[b, valid].var(dim=0).mean().item() + ) + else: + pred_var_list.append(eval_preds.var(dim=0).mean().item()) + avg_loss = running_loss / len(loader) avg_mae = running_mae / len(loader) avg_pcc = sum(pcc_list) / len(pcc_list) if pcc_list else None @@ -268,8 +333,12 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) sum(attn_correlations) / len(attn_correlations) if attn_correlations else None ) + avg_pred_var = sum(pred_var_list) / len(pred_var_list) if pred_var_list else None + if avg_pcc is not None: print(f"Validation MAE: {avg_mae:.4f} | PCC: {avg_pcc:.4f}") + if avg_pred_var is not None: + print(f"Prediction Variance: {avg_pred_var:.6f}") if avg_corr is not None: print(f"Spatial Attention Correlation: {avg_corr:.4f}") @@ -277,5 +346,6 @@ def validate(model, loader, criterion, device, whole_slide=False, use_amp=False) "val_loss": avg_loss, "val_mae": avg_mae, "val_pcc": avg_pcc, + "pred_variance": avg_pred_var, "attn_correlation": avg_corr, } diff --git a/src/spatial_transcript_former/training/losses.py b/src/spatial_transcript_former/training/losses.py index ffdfae0..e9a3bc2 100644 --- a/src/spatial_transcript_former/training/losses.py +++ b/src/spatial_transcript_former/training/losses.py @@ -37,25 +37,69 @@ def forward(self, preds, target, mask=None): Returns: Scalar loss = 1 - mean(PCC). """ - if preds.dim() == 3: - B, N, G = preds.shape - preds = preds.reshape(-1, G) # (B*N, G) - target = target.reshape(-1, G) + if preds.dim() == 2: + preds = preds.unsqueeze(1) # (B, 1, G) + target = target.unsqueeze(1) # (B, 1, G) + if mask is not None: + mask = mask.unsqueeze(1) # (B, 1) + + B, N, G = preds.shape + # If N == 1 (e.g., standard patch-wise prediction without context), + # PCC across a spatial dimension of 1 is undefined (variance is 0). + # We fallback to batch-wise correlation in this specific edge case. + if N == 1: + preds = preds.squeeze(1) # (B, G) + target = target.squeeze(1) # (B, G) if mask is not None: - valid = ~mask.reshape(-1) # (B*N,) + valid = ~mask.squeeze(1) # (B) preds = preds[valid] target = target[valid] - # Centre per-spot - vx = preds - preds.mean(dim=0, keepdim=True) - vy = target - target.mean(dim=0, keepdim=True) + if preds.shape[0] < 2: + return torch.tensor(0.0, device=preds.device, requires_grad=True) + + vx = preds - preds.mean(dim=0, keepdim=True) + vy = target - target.mean(dim=0, keepdim=True) + cost = torch.sum(vx * vy, dim=0) / ( + torch.sqrt(torch.sum(vx**2, dim=0) + self.eps) + * torch.sqrt(torch.sum(vy**2, dim=0) + self.eps) + ) + return 1 - cost.mean() + + # 1. Masking: Zero out padded positions so they don't contribute to sums + if mask is not None: + valid = ~mask.unsqueeze(-1) # (B, N, 1) + preds = preds * valid.float() + target = target * valid.float() + # valid_counts: (B, 1, 1) to enable broadcasting across N (dim 1) and G (dim 2) + valid_counts = valid.sum(dim=1, keepdim=True).clamp(min=1.0) + else: + valid_counts = torch.tensor(N, dtype=torch.float32, device=preds.device) - # Correlation - cost = torch.sum(vx * vy, dim=0) / ( - torch.sqrt(torch.sum(vx**2, dim=0) + self.eps) - * torch.sqrt(torch.sum(vy**2, dim=0) + self.eps) - ) + # 2. Centre per-slide (over spatial dimension N) + # Calculate mean using valid counts to avoid skew from zeros + pred_means = preds.sum(dim=1, keepdim=True) / valid_counts + target_means = target.sum(dim=1, keepdim=True) / valid_counts + + vx = preds - pred_means + vy = target - target_means + + if mask is not None: + vx = vx * valid.float() + vy = vy * valid.float() + + # 3. Correlation (covariance / (std_x * std_y)) per slide, per gene + # sum over spatial dimension N -> (B, G) + cov = torch.sum(vx * vy, dim=1) + var_x = torch.sum(vx**2, dim=1) + var_y = torch.sum(vy**2, dim=1) + + cost = cov / ( + torch.sqrt(var_x + self.eps) * torch.sqrt(var_y + self.eps) + ) # (B, G) + + # Average across genes, then average across batch return 1 - cost.mean() @@ -143,3 +187,151 @@ def forward(self, preds, target, mask=None): return (loss * valid.float()).sum() / valid.sum() return loss.mean() + + +class ZINBLoss(nn.Module): + """ + Zero-Inflated Negative Binomial (ZINB) Loss. + + Designed for highly dispersed, zero-inflated count data (like raw RNA-seq). + Requires the model to output three parameters per gene: + - pi: probability of zero inflation (dropout) + - mu: mean of the negative binomial distribution + - theta: inverse dispersion parameter + """ + + def __init__(self, eps=1e-8): + super().__init__() + self.eps = eps + + def forward(self, preds, target, mask=None): + """ + Args: + preds: Tuple of (pi, mu, theta), each (B, G) or (B, N, G). + pi should be [0, 1] (e.g. via Sigmoid). + mu, theta should be > 0 (e.g. via Softplus or Exp). + target: Raw integer counts (B, G) or (B, N, G). + mask: (B, N) boolean, True = padded (ignore). Optional. + + Returns: + Scalar negative log-likelihood over valid positions. + """ + pi, mu, theta = preds + + # Add numerical stability to theta and mu + theta = torch.clamp(theta, min=self.eps, max=1e6) + mu = torch.clamp(mu, min=self.eps, max=1e6) + + # Ensure target is valid for ZINB (no negatives, we expect counts) + target = torch.clamp(target, min=0) + + # 1. Negative Binomial Probability + # NB(y; mu, theta) = Gamma(y+theta)/(Gamma(theta)*y!) * (theta/(theta+mu))^theta * (mu/(theta+mu))^y + # Log version for stability using lgamma: + # ln(Gamma(target + theta)) - ln(Gamma(theta)) - ln(Gamma(target + 1)) + # + theta * ln(theta / (theta + mu)) + target * ln(mu / (theta + mu)) + + t1 = ( + torch.lgamma(target + theta) + - torch.lgamma(theta) + - torch.lgamma(target + 1) + ) + t2 = theta * (torch.log(theta + self.eps) - torch.log(theta + mu + self.eps)) + t3 = target * (torch.log(mu + self.eps) - torch.log(theta + mu + self.eps)) + + nb_log_prob = t1 + t2 + t3 + + # 2. Zero Inflation Mask + is_zero = (target == 0).float() + + # 3. ZINB Log Likelihood + # 3a. If target == 0: ln(pi + (1-pi) * NB(0; mu, theta)) + # NB(0; mu, theta) = (theta / (theta + mu))^theta + # Log space for stability: theta * (log(theta) - log(theta + mu)) + nb_zero_log_prob = theta * ( + torch.log(theta + self.eps) - torch.log(theta + mu + self.eps) + ) + + # zero_case_prob = pi + (1 - pi) * exp(nb_zero_log_prob) + # We need ln(pi + (1-pi)*exp(nb_zero_log_prob)). + # clamp heavily before taking log to prevent NaNs when pi -> 0 and exp(nb) -> 0 + zero_case_prob = pi + (1 - pi) * torch.exp(nb_zero_log_prob) + zero_case_prob = torch.clamp(zero_case_prob, min=self.eps, max=1.0) + zero_case_log_prob = torch.log(zero_case_prob) + + # 3b. If target > 0: ln((1-pi) * NB(target; mu, theta)) + # = ln(1-pi) + ln(NB) + non_zero_case_log_prob = torch.log(1 - pi + self.eps) + nb_log_prob + + # Combine + log_likelihood = ( + is_zero * zero_case_log_prob + (1 - is_zero) * non_zero_case_log_prob + ) + + if mask is not None and pi.dim() == 3: + # Expand mask to gene dimension: (B, N) -> (B, N, G) + valid = ~mask.unsqueeze(-1).expand_as(log_likelihood) + return -(log_likelihood * valid.float()).sum() / valid.sum() + + return -log_likelihood.mean() + + +class AuxiliaryPathwayLoss(nn.Module): + """Wrapper combining gene-level loss with PCC-based pathway auxiliary loss. + + Directly supervises the pathway bottleneck scores using PCC against + pathway ground truth computed from gene expression via MSigDB membership. + This provides a direct gradient signal to the pathway tokens, preventing + bottleneck collapse and ensuring each pathway learns its correct spatial + activation pattern. + + The total loss is:: + + L = L_gene + lambda * (1 - PCC(pathway_scores, target_pathways)) + + where ``target_pathways = target_genes @ M.T`` and M is the MSigDB + membership matrix. + """ + + def __init__(self, pathway_matrix, gene_criterion, lambda_pathway=0.5): + """ + Args: + pathway_matrix (torch.Tensor): Binary MSigDB membership matrix + of shape ``(P, G)`` where ``P`` is the number of pathways and + ``G`` is the number of genes. + gene_criterion (nn.Module): Base loss for gene prediction + (e.g. ``MaskedMSELoss``, ``CompositeLoss``). + lambda_pathway (float): Weight for the auxiliary pathway loss. + """ + super().__init__() + self.register_buffer("pathway_matrix", pathway_matrix.float()) + self.gene_criterion = gene_criterion + self.lambda_pathway = lambda_pathway + self.pcc = PCCLoss() + + def forward(self, gene_preds, target_genes, mask=None, pathway_preds=None): + """ + Args: + gene_preds: (B, G) or (B, N, G) predicted gene expression. + target_genes: (B, G) or (B, N, G) ground truth gene expression. + mask: (B, N) boolean, True = padded (ignore). Optional. + pathway_preds: (B, P) or (B, N, P) predicted pathway scores. + If None, only gene loss is computed (graceful fallback). + + Returns: + Scalar total loss. + """ + gene_loss = self.gene_criterion(gene_preds, target_genes, mask=mask) + + if pathway_preds is None or self.lambda_pathway == 0: + return gene_loss + + # Compute pathway ground truth from gene expression + # target_genes: (B, [N,] G), pathway_matrix: (P, G) + # result: (B, [N,] P) + with torch.no_grad(): + target_pathways = torch.matmul(target_genes, self.pathway_matrix.T) + + pathway_loss = self.pcc(pathway_preds, target_pathways, mask=mask) + + return gene_loss + self.lambda_pathway * pathway_loss diff --git a/src/spatial_transcript_former/visualization.py b/src/spatial_transcript_former/visualization.py index e7b7a39..1e3b367 100644 --- a/src/spatial_transcript_former/visualization.py +++ b/src/spatial_transcript_former/visualization.py @@ -152,23 +152,12 @@ def run_inference_plot(model, args, sample_id, epoch, device): h5_path = os.path.join(patches_dir, f"{sample_id}.h5") h5ad_path = os.path.join(st_dir, f"{sample_id}.h5ad") - with h5py.File(h5_path, "r") as f: - patch_barcodes = f["barcode"][:].flatten() - coords = f["coords"][:] - # Load global genes try: common_gene_names = load_global_genes(args.data_dir, args.num_genes) except Exception: common_gene_names = None - gene_matrix, mask, gene_names = load_gene_expression_matrix( - h5ad_path, - patch_barcodes, - selected_gene_names=common_gene_names, - num_genes=args.num_genes, - ) - # Run inference preds = [] pathways_list = [] @@ -194,18 +183,53 @@ def run_inference_plot(model, args, sample_id, epoch, device): selected_gene_names=common_gene_names, n_neighbors=args.n_neighbors, whole_slide_mode=args.whole_slide, + log1p=log_transform, ) if args.whole_slide: - feats, _, _ = ds[0] + feats, gene_targets, slide_coords = ds[0] feats = feats.unsqueeze(0).to(device) - if hasattr(model, "forward_dense"): - output = model.forward_dense(feats, return_pathways=True) + slide_coords = slide_coords.unsqueeze(0).to(device) + if isinstance(model, SpatialTranscriptFormer): + output = model( + feats, + return_dense=True, + rel_coords=slide_coords, + return_pathways=True, + ) if isinstance(output, tuple): - preds.append(output[0].detach().cpu().squeeze(0)) + out_preds = output[0] + if isinstance(out_preds, tuple): + out_preds = out_preds[1] + preds.append(out_preds.detach().cpu().squeeze(0)) pathways_list.append(output[1].detach().cpu().squeeze(0)) else: preds.append(output.detach().cpu().squeeze(0)) + + # Use raw pixel coords from the .pt file for histology overlay + # These are guaranteed to be in the same order as the features + saved_data = torch.load( + feature_path, map_location="cpu", weights_only=True + ) + raw_coords = saved_data["coords"] # (N, 2) + raw_barcodes = saved_data["barcodes"] + del saved_data + + # Compute the same mask the dataset used to filter + _, pt_mask, gene_names = load_gene_expression_matrix( + h5ad_path, + raw_barcodes, + selected_gene_names=common_gene_names, + num_genes=args.num_genes, + ) + pt_mask_bool = np.array(pt_mask, dtype=bool) + coord_subset = raw_coords[pt_mask_bool].numpy() + + # Pathway truth from the dataset's aligned gene matrix + gene_truth = gene_targets.numpy() + pathway_truth, pathway_names = _compute_pathway_truth( + gene_truth, gene_names, args=args + ) else: dl = DataLoader(ds, batch_size=32, shuffle=False) for feats, _, rel_coords_batch in dl: @@ -217,13 +241,43 @@ def run_inference_plot(model, args, sample_id, epoch, device): ) if isinstance(output, tuple): pathways_list.append(output[1].cpu()) - preds.append(output[0].cpu()) + out_preds = output[0] + if isinstance(out_preds, tuple): + out_preds = out_preds[1] + preds.append(out_preds.cpu()) else: preds.append(output.cpu()) else: preds.append(model(feats.to(device)).cpu()) + + # Non-whole-slide: use h5 file coords (same source as DataLoader) + with h5py.File(h5_path, "r") as f: + patch_barcodes = f["barcode"][:].flatten() + h5_coords = f["coords"][:] + gene_matrix, mask, gene_names = load_gene_expression_matrix( + h5ad_path, + patch_barcodes, + selected_gene_names=common_gene_names, + num_genes=args.num_genes, + ) + coord_mask = np.array(mask, dtype=bool) + coord_subset = h5_coords[coord_mask] + gene_truth = np.log1p(gene_matrix) if log_transform else gene_matrix + pathway_truth, pathway_names = _compute_pathway_truth( + gene_truth, gene_names, args=args + ) else: - coord_subset = coords[mask] + with h5py.File(h5_path, "r") as f: + patch_barcodes = f["barcode"][:].flatten() + h5_coords = f["coords"][:] + gene_matrix, mask, gene_names = load_gene_expression_matrix( + h5ad_path, + patch_barcodes, + selected_gene_names=common_gene_names, + num_genes=args.num_genes, + ) + coord_mask = np.array(mask, dtype=bool) + coord_subset = h5_coords[coord_mask] ds = HEST_Dataset( h5_path, coord_subset, gene_matrix, indices=np.where(mask)[0] ) @@ -237,28 +291,32 @@ def run_inference_plot(model, args, sample_id, epoch, device): ) if isinstance(output, tuple): pathways_list.append(output[1].cpu()) - preds.append(output[0].cpu()) + out_preds = output[0] + if isinstance(out_preds, tuple): + out_preds = out_preds[1] + preds.append(out_preds.cpu()) else: preds.append(output.cpu()) else: preds.append(model(imgs.to(device)).cpu()) + gene_truth = np.log1p(gene_matrix) if log_transform else gene_matrix + pathway_truth, pathway_names = _compute_pathway_truth( + gene_truth, gene_names, args=args + ) - if not preds or not pathways_list: - print("Warning: No predictions/pathways generated. Skipping plot.") + if not preds: + print("Warning: No predictions generated. Skipping plot.") return - pathways = torch.cat(pathways_list, dim=0).numpy() - coord_mask = np.array(mask, dtype=bool) - coord_subset = coords[coord_mask] - - # Compute pathway ground truth from MSigDB membership (fixed across epochs) - gene_truth = np.log1p(gene_matrix) if log_transform else gene_matrix - pathway_truth, pathway_names = _compute_pathway_truth( - gene_truth, gene_names, args=args + # Compute pathway activations from gene predictions (same method as truth) + # Both truth and pred are now: mean gene expression of pathway members + gene_preds_np = torch.cat(preds, dim=0).numpy() + pathway_pred, _ = _compute_pathway_truth( + gene_preds_np, gene_names, args=args ) - if pathway_truth is None: - print("Warning: Could not compute pathway truth. Skipping plot.") + if pathway_truth is None or pathway_pred is None: + print("Warning: Could not compute pathway truth/pred. Skipping plot.") return # Load histology image @@ -270,7 +328,7 @@ def run_inference_plot(model, args, sample_id, epoch, device): ) plot_training_summary( coord_subset, - pathways, + pathway_pred, pathway_truth, pathway_names, sample_id=sample_id, diff --git a/tests/test_bottleneck_arch.py b/tests/test_bottleneck_arch.py index a2effba..f2eea5e 100644 --- a/tests/test_bottleneck_arch.py +++ b/tests/test_bottleneck_arch.py @@ -16,9 +16,11 @@ def test_interaction_output_shape(): # Dummy input (Batch, Channel, Height, Width) x = torch.randn(4, 3, 224, 224) + # Single patch => S=1 + rel_coords = torch.randn(4, 1, 2) # Forward pass - output = model(x) + output = model(x, rel_coords=rel_coords) # Verify shape (Batch, num_genes) assert output.shape == ( @@ -32,8 +34,9 @@ def test_interaction_gradient_flow(): num_genes = 1000 model = SpatialTranscriptFormer(num_genes=num_genes) x = torch.randn(2, 3, 224, 224) + rel_coords = torch.randn(2, 1, 2) - output = model(x) + output = model(x, rel_coords=rel_coords) loss = output.sum() loss.backward() diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 1e86a27..08fe63b 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -27,22 +27,7 @@ def small_model(): num_pathways=10, token_dim=64, n_heads=4, - n_layers=1, - use_nystrom=False, - ) - - -@pytest.fixture -def small_nystrom_model(): - """A small SpatialTranscriptFormer with Nystrom attention.""" - return SpatialTranscriptFormer( - num_genes=100, - num_pathways=10, - token_dim=64, - n_heads=4, - n_layers=1, - use_nystrom=True, - num_landmarks=32, + n_layers=2, ) @@ -85,8 +70,7 @@ def test_save_load_preserves_weights(self, small_model, checkpoint_dir): num_pathways=10, token_dim=64, n_heads=4, - n_layers=1, - use_nystrom=False, + n_layers=2, ) fresh_optimizer = optim.Adam(fresh_model.parameters(), lr=1e-4) @@ -125,8 +109,7 @@ def test_save_load_preserves_scaler(self, small_model, checkpoint_dir): num_pathways=10, token_dim=64, n_heads=4, - n_layers=1, - use_nystrom=False, + n_layers=2, ) fresh_optimizer = optim.Adam(fresh_model.parameters(), lr=1e-4) fresh_scaler = torch.amp.GradScaler("cuda") @@ -151,75 +134,3 @@ def test_no_checkpoint_starts_fresh(self, small_model, checkpoint_dir): ) assert start_epoch == 0 assert best_val == float("inf") - - -# --------------------------------------------------------------------------- -# Architecture Mismatch -# --------------------------------------------------------------------------- - - -class TestArchitectureMismatch: - def test_nystrom_checkpoint_fails_on_standard( - self, small_nystrom_model, checkpoint_dir - ): - """Nystrom checkpoint should fail to load into standard model.""" - optimizer = optim.Adam(small_nystrom_model.parameters(), lr=1e-4) - save_checkpoint( - small_nystrom_model, - optimizer, - None, - epoch=5, - best_val_loss=0.3, - output_dir=checkpoint_dir, - model_name="interaction", - ) - - # Try loading into non-Nystrom model - standard_model = SpatialTranscriptFormer( - num_genes=100, - num_pathways=10, - token_dim=64, - n_heads=4, - n_layers=1, - use_nystrom=False, - ) - standard_optimizer = optim.Adam(standard_model.parameters(), lr=1e-4) - - with pytest.raises(RuntimeError): - load_checkpoint( - standard_model, - standard_optimizer, - None, - checkpoint_dir, - "interaction", - "cpu", - ) - - def test_matching_architecture_loads(self, small_nystrom_model, checkpoint_dir): - """Nystrom checkpoint should load into matching Nystrom model.""" - optimizer = optim.Adam(small_nystrom_model.parameters(), lr=1e-4) - save_checkpoint( - small_nystrom_model, - optimizer, - None, - epoch=5, - best_val_loss=0.3, - output_dir=checkpoint_dir, - model_name="interaction", - ) - - fresh = SpatialTranscriptFormer( - num_genes=100, - num_pathways=10, - token_dim=64, - n_heads=4, - n_layers=1, - use_nystrom=True, - num_landmarks=32, - ) - fresh_opt = optim.Adam(fresh.parameters(), lr=1e-4) - - start_epoch, _ = load_checkpoint( - fresh, fresh_opt, None, checkpoint_dir, "interaction", "cpu" - ) - assert start_epoch == 6 # Resumed correctly diff --git a/tests/test_interactions.py b/tests/test_interactions.py new file mode 100644 index 0000000..6847538 --- /dev/null +++ b/tests/test_interactions.py @@ -0,0 +1,176 @@ +import torch +import torch.nn as nn +import sys +import os +import numpy as np + +# Add src to path +sys.path.append(os.path.abspath("src")) + +from spatial_transcript_former.models.interaction import SpatialTranscriptFormer + + +def test_mask_logic(): + print("Testing Mask Logic...") + model = SpatialTranscriptFormer( + num_genes=100, num_pathways=10, interactions=["p2p", "p2h"] + ) + + # p=10, s=20 + mask = model._build_interaction_mask(10, 20, "cpu") + + # Check p2p (allowed) + assert mask[0, 1] == False, "p2p should be allowed" + # Check p2h (allowed) + assert mask[0, 15] == False, "p2h should be allowed" + # Check h2p (blocked) + assert mask[15, 0] == True, "h2p should be blocked" + # Check h2h (blocked) + assert mask[15, 16] == True, "h2h should be blocked" + + print("Mask logic test passed!") + + +def test_connectivity(): + print("Testing Patch-to-Patch Connectivity (h2h)...") + # Enable all interactions including h2h + model = SpatialTranscriptFormer( + num_genes=100, + num_pathways=10, + n_layers=2, + interactions=["p2p", "p2h", "h2p", "h2h"], + use_spatial_pe=False, # Disable PE to verify raw attention connectivity + pretrained=False, + ) + model.eval() + + # dummy features (B=1, S=2, D=2048) + feats = torch.randn(1, 2, 2048, requires_grad=True) + coords = torch.randn(1, 2, 2) + + # We want to see if the output for patch 0 depends on the input of patch 1 + # We use return_dense=True to get per-patch gene predictions + output = model(feats, rel_coords=coords, return_dense=True) # (1, 2, 100) + + # Loss on patch 0 output + loss = output[0, 0].sum() + loss.backward() + + # Check if gradient flows to patch 1 input + grad_patch_1 = feats.grad[0, 1].norm() + print(f"Gradient at Patch 1 from Patch 0 output: {grad_patch_1.item():.6e}") + + assert ( + grad_patch_1 > 0 + ), "Patch 0 output should depend on Patch 1 input when h2h is enabled" + + # Now try with h2h disabled + print("Testing Connectivity with h2h disabled...") + model_no_h2h = SpatialTranscriptFormer( + num_genes=100, + num_pathways=10, + n_layers=2, + interactions=["p2p", "p2h", "h2p"], + use_spatial_pe=False, + pretrained=False, + ) + model_no_h2h.eval() + + feats_2 = torch.randn(1, 2, 2048, requires_grad=True) + output_2 = model_no_h2h(feats_2, rel_coords=coords, return_dense=True) + + loss_2 = output_2[0, 0].sum() + loss_2.backward() + + grad_patch_1_no_h2h = feats_2.grad[0, 1].norm() + print( + f"Gradient at Patch 1 from Patch 0 output (no h2h): {grad_patch_1_no_h2h.item():.6e}" + ) + + # It should still be non-zero because patches interact via pathways [Patch 1 -> Pathway -> Patch 0] + assert ( + grad_patch_1_no_h2h > 0 + ), "Patch 0 should still interact with Patch 1 via pathways even if h2h is disabled" + + # To truly see zero interaction, block pathways too + print("Testing ZERO Connectivity (only p2p enabled)...") + model_isolated = SpatialTranscriptFormer( + num_genes=100, + num_pathways=10, + n_layers=2, + interactions=["p2p"], + use_spatial_pe=False, + pretrained=False, + ) + model_isolated.eval() + feats_3 = torch.randn(1, 2, 2048, requires_grad=True) + output_3 = model_isolated(feats_3, rel_coords=coords, return_dense=True) + loss_3 = output_3[0, 0].sum() + loss_3.backward() + grad_patch_1_isolated = feats_3.grad[0, 1].norm() + print(f"Gradient at Patch 1 (fully isolated): {grad_patch_1_isolated.item():.6e}") + assert ( + grad_patch_1_isolated < 1e-10 + ), "Patch 0 should NOT depend on Patch 1 when only p2p is enabled" + + print("Connectivity tests passed!") + + +def test_attention_extraction(): + print("Testing Attention Extraction...") + p, s = 10, 20 + model = SpatialTranscriptFormer( + num_genes=100, + num_pathways=p, + n_layers=2, + interactions=["p2p", "p2h"], # Block h2p, h2h + pretrained=False, + ) + model.eval() + + feats = torch.randn(1, s, 2048) + coords = torch.randn(1, s, 2) + + # Forward with attention + _, attentions = model(feats, rel_coords=coords, return_attention=True) + + # attentions is list of weights [layers] + for i, attn in enumerate(attentions): + print(f"Testing Layer {i}...") + # attn is (B, T, T) + assert attn.shape == (1, p + s, p + s) + + # We expect blocked regions to have 0 attention + h2p_region = attn[0, p:, :p] + h2h_region = attn[0, p:, p:] + + # For h2h, we must ignore diagonal + h2h_off_diag = h2h_region.clone() + h2h_off_diag.fill_diagonal_(0) + + print(f"Layer {i} h2p attention max: {h2p_region.max().item():.2e}") + print(f"Layer {i} h2h off-diag attention max: {h2h_off_diag.max().item():.2e}") + + assert ( + h2p_region.max() < 1e-10 + ), f"Layer {i} h2p attention should be zero when blocked" + assert ( + h2h_off_diag.max() < 1e-10 + ), f"Layer {i} h2h attention should be zero when blocked" + + # Check that allowed regions have non-zero attention + p2p_region = attn[0, :p, :p] + p2h_region = attn[0, :p, p:] + print(f"Layer {i} p2p attention max: {p2p_region.max().item():.2e}") + print(f"Layer {i} p2h attention max: {p2h_region.max().item():.2e}") + + assert p2p_region.max() > 0, f"Layer {i} p2p attention should be non-zero" + assert p2h_region.max() > 0, f"Layer {i} p2h attention should be non-zero" + + print("Attention extraction test passed!") + + +if __name__ == "__main__": + test_mask_logic() + test_connectivity() + test_attention_extraction() diff --git a/tests/test_local_patch_mixer.py b/tests/test_local_patch_mixer.py deleted file mode 100644 index 20a5767..0000000 --- a/tests/test_local_patch_mixer.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np -import pytest -from spatial_transcript_former.models.interaction import LocalPatchMixer - - -def test_local_patch_mixer_basic(): - """Test that output shape is correct and forward pass works.""" - B, N, D = 2, 10, 32 - mixer = LocalPatchMixer(dim=D, kernel_size=3) - - # Grid of coordinates (2x5) - coords = [] - for b in range(B): - bc = [] - for y in range(2): - for x in range(5): - bc.append([x, y]) - coords.append(bc) - coords = torch.tensor(coords, dtype=torch.float32) - - x = torch.randn(B, N, D) - out = mixer(x, coords) - - assert out.shape == x.shape - # Check that residual connection worked (output is not zero) - assert not torch.allclose(out, torch.zeros_like(out)) - - -def test_local_patch_mixer_sparse(): - """Test that gaps in coordinates are handled.""" - dim = 8 - mixer = LocalPatchMixer(dim=dim, kernel_size=3) - - # 3 patches in a line with a gap: (0,0), (2,0), (4,0) - x = torch.zeros(1, 3, dim) - x[0, 0, 0] = 1.0 - x[0, 1, 1] = 1.0 - x[0, 2, 2] = 1.0 - - coords = torch.tensor([[[0, 0], [2, 0], [4, 0]]], dtype=torch.float32) - - with torch.no_grad(): - mixer.conv.weight.fill_(0.0) - mixer.conv.bias.fill_(0.0) - # Set weight to look at neighbor at x+1 (weight[1, 2] in 3x3) - mixer.conv.weight[:, 0, 1, 2] = 1.0 - - out = mixer(x, coords) - - # The patch at 2.0 looks at 3.0 (empty). Conv result should be 0. - # Out = x + 0 = x. - # Check that ALL values are exactly as original x - torch.testing.assert_close(out, x, atol=1e-4, rtol=1e-4) - - -def test_local_patch_mixer_neighbor_influence(): - """Verify that neighbors actually influence the center patch.""" - dim = 4 - mixer = LocalPatchMixer(dim=dim, kernel_size=3) - - # 3 patches adjacent: (0,0), (1,0), (2,0) - x = torch.zeros(1, 3, dim) - # Patch 1 (center) has signal in channel 0 - x[0, 1, 0] = 1.0 - - coords = torch.tensor([[[0, 0], [1, 0], [2, 0]]], dtype=torch.float32) - - with torch.no_grad(): - mixer.conv.weight.fill_(0.0) - mixer.conv.bias.fill_(0.0) - # Weight[1, 0] looks at neighbor x-1 - # So output at Grid(1,0) looks at Grid(0,0) - # Output at Grid(2,0) looks at Grid(1,0) - mixer.conv.weight[:, 0, 1, 0] = 10.0 - - out = mixer(x, coords) - - # Patch 2 (at 2,0) should receive signal from Patch 1 (at 1,0) - # Patch 1 had 1.0. Conv result at 2,0 should be 1.0 * 10.0 = 10.0. - # After GELU: GELU(10) is approx 10.0. - # Residual: 0.0 + 10.0 = 10.0. - assert out[0, 2, 0] > 9.0 - - # Patch 1 (at 1,0) should receive signal from Patch 0 (at 0,0) which is 0. - # Residual: 1.0 + 0.0 = 1.0. - assert out[0, 1, 0] == 1.0 - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/test_losses.py b/tests/test_losses.py index 8351286..e9ffbeb 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -1,5 +1,6 @@ """ -Tests for loss functions: MaskedMSELoss, PCCLoss, CompositeLoss. +Tests for loss functions: MaskedMSELoss, PCCLoss, CompositeLoss, +AuxiliaryPathwayLoss. Verifies correctness of masking, scale invariance, gradients, and composite objective decomposition. @@ -12,6 +13,7 @@ MaskedMSELoss, PCCLoss, CompositeLoss, + AuxiliaryPathwayLoss, ) # --------------------------------------------------------------------------- @@ -184,3 +186,294 @@ def test_different_alphas(self, tensors_2d): loss_high = CompositeLoss(alpha=10.0)(preds, target) # They should differ since PCC != 0 assert loss_low.item() != pytest.approx(loss_high.item(), abs=0.01) + + +# --------------------------------------------------------------------------- +# ZINBLoss +# --------------------------------------------------------------------------- + + +@pytest.fixture +def zinb_tensors(): + """Tensors for ZINB testing: pi, mu, theta, and counts.""" + torch.manual_seed(42) + B, G = 16, 50 + pi = torch.rand(B, G) # [0, 1] + mu = torch.rand(B, G) * 10 + 0.1 # > 0 + theta = torch.rand(B, G) * 5 + 0.1 # > 0 + + # Simulate some zero-inflated counts + counts = torch.poisson(mu) + dropout_mask = torch.rand(B, G) < 0.3 + counts[dropout_mask] = 0.0 + + return (pi, mu, theta), counts + + +class TestZINBLoss: + def test_basic_computation(self, zinb_tensors): + """ZINBLoss should compute a finite scalar loss.""" + from spatial_transcript_former.training.losses import ZINBLoss + + preds, target = zinb_tensors + loss_fn = ZINBLoss() + loss = loss_fn(preds, target) + + assert loss.isfinite() + assert loss.item() > 0 # NLL should be positive + + def test_gradient_flow(self, zinb_tensors): + """Gradients should flow through all three parameters.""" + from spatial_transcript_former.training.losses import ZINBLoss + + pi, mu, theta = zinb_tensors[0] + pi = pi.clone().requires_grad_(True) + mu = mu.clone().requires_grad_(True) + theta = theta.clone().requires_grad_(True) + + target = zinb_tensors[1] + loss = ZINBLoss()((pi, mu, theta), target) + loss.backward() + + assert pi.grad is not None + assert mu.grad is not None + assert theta.grad is not None + assert not torch.isnan(pi.grad).any() + assert not torch.isnan(mu.grad).any() + assert not torch.isnan(theta.grad).any() + + def test_mask_support(self): + """ZINBLoss should handle masks in 3D mode correctly.""" + from spatial_transcript_former.training.losses import ZINBLoss + + torch.manual_seed(42) + B, N, G = 2, 10, 5 + pi = torch.rand(B, N, G) + mu = torch.rand(B, N, G) * 5 + 0.1 + theta = torch.rand(B, N, G) * 5 + 0.1 + target = torch.randint(0, 10, (B, N, G)).float() + + mask = torch.zeros(B, N, dtype=torch.bool) + mask[0, 5:] = True # sample 0 half padded + + pi.requires_grad_(True) + + loss_fn = ZINBLoss() + loss = loss_fn((pi, mu, theta), target, mask=mask) + loss.backward() + + assert loss.isfinite() + + # Gradients in padded regions should be zero + padded_grad = pi.grad[0, 5:, :] + assert padded_grad.abs().sum() == 0.0 + + +# --------------------------------------------------------------------------- +# AuxiliaryPathwayLoss +# --------------------------------------------------------------------------- + + +@pytest.fixture +def pathway_tensors(): + """Tensors for AuxiliaryPathwayLoss testing.""" + torch.manual_seed(42) + B, N, G, P = 2, 100, 50, 10 + gene_preds = torch.randn(B, N, G) + target_genes = torch.randn(B, N, G).abs() # Positive counts + pathway_preds = torch.randn(B, N, P) + # Binary MSigDB-like matrix: (P, G) + pathway_matrix = (torch.rand(P, G) > 0.8).float() + mask = torch.zeros(B, N, dtype=torch.bool) + mask[0, 80:] = True + return gene_preds, target_genes, pathway_preds, pathway_matrix, mask + + +class TestAuxiliaryPathwayLoss: + def test_basic_computation(self, pathway_tensors): + """AuxiliaryPathwayLoss should produce a finite scalar.""" + gene_preds, targets, pw_preds, pw_matrix, mask = pathway_tensors + loss_fn = AuxiliaryPathwayLoss(pw_matrix, MaskedMSELoss(), lambda_pathway=0.5) + loss = loss_fn(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + assert loss.isfinite() + + def test_includes_gene_loss(self, pathway_tensors): + """Total loss should be >= gene loss alone.""" + gene_preds, targets, pw_preds, pw_matrix, mask = pathway_tensors + base = MaskedMSELoss() + aux = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=0.5) + gene_only = base(gene_preds, targets, mask=mask) + total = aux(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + # Pathway PCC loss is >= 0, so total >= gene_only + assert total.item() >= gene_only.item() - 1e-5 + + def test_gradient_flows_to_both(self, pathway_tensors): + """Gradients should flow to both gene_preds and pathway_preds.""" + gene_preds, targets, pw_preds, pw_matrix, mask = pathway_tensors + gene_preds = gene_preds.clone().requires_grad_(True) + pw_preds = pw_preds.clone().requires_grad_(True) + loss_fn = AuxiliaryPathwayLoss(pw_matrix, MaskedMSELoss(), lambda_pathway=0.5) + loss = loss_fn(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + loss.backward() + assert gene_preds.grad is not None + assert pw_preds.grad is not None + assert gene_preds.grad.abs().sum() > 0 + assert pw_preds.grad.abs().sum() > 0 + + def test_fallback_without_pathways(self, pathway_tensors): + """When pathway_preds is None, should fall back to gene loss only.""" + gene_preds, targets, _, pw_matrix, mask = pathway_tensors + base = MaskedMSELoss() + aux = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=0.5) + gene_only = base(gene_preds, targets, mask=mask) + fallback = aux(gene_preds, targets, mask=mask, pathway_preds=None) + assert torch.allclose(gene_only, fallback) + + def test_lambda_zero_disables(self, pathway_tensors): + """lambda_pathway=0 should produce the same result as gene loss.""" + gene_preds, targets, pw_preds, pw_matrix, mask = pathway_tensors + base = MaskedMSELoss() + aux = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=0.0) + gene_only = base(gene_preds, targets, mask=mask) + total = aux(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + assert torch.allclose(gene_only, total) + + def test_zinb_integration(self): + """AuxiliaryPathwayLoss should work with ZINB (pi, mu, theta) output.""" + from spatial_transcript_former.training.losses import ( + ZINBLoss, + AuxiliaryPathwayLoss, + MaskedMSELoss, + ) + + torch.manual_seed(42) + B, N, G, P = 2, 10, 50, 10 + pi = torch.rand(B, N, G) + mu = torch.rand(B, N, G) * 5 + 1.0 + theta = torch.rand(B, N, G) * 5 + 1.0 + zinb_preds = (pi, mu, theta) + + targets = torch.randint(0, 10, (B, N, G)).float() + pw_preds = torch.randn(B, N, P) + pw_matrix = (torch.rand(P, G) > 0.8).float() + mask = torch.zeros(B, N, dtype=torch.bool) + mask[0, 5:] = True + + loss_fn = AuxiliaryPathwayLoss(pw_matrix, ZINBLoss(), lambda_pathway=1.0) + loss = loss_fn(zinb_preds, targets, mask=mask, pathway_preds=pw_preds) + + assert loss.isfinite() + assert loss.item() > 0 + + def test_perfect_match_zero_aux(self, pathway_tensors): + """When pathway preds perfectly correlate with truth, aux loss contribution is 0.""" + gene_preds, targets, _, pw_matrix, mask = pathway_tensors + base = MaskedMSELoss() + aux = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=1.0) + + # Compute ground truth pathways + with torch.no_grad(): + target_pathways = torch.matmul(targets, pw_matrix.T) + + gene_loss = base(gene_preds, targets, mask=mask) + # Use target_pathways as pathway_preds + total_loss = aux(gene_preds, targets, mask=mask, pathway_preds=target_pathways) + + # Total loss should roughly equal gene loss (since 1-PCC should be 0) + assert total_loss.item() == pytest.approx(gene_loss.item(), abs=1e-5) + + def test_numerical_stability_empty_spots(self, pathway_tensors): + """Test stability when targets for some pathways are all zero.""" + gene_preds, targets, pw_preds, pw_matrix, mask = pathway_tensors + + # Zero out targets for the first pathway + # Matrix is (P, G), so genes involved in pathway 0 + genes_in_p0 = pw_matrix[0].bool() + targets[..., genes_in_p0] = 0.0 + + loss_fn = AuxiliaryPathwayLoss(pw_matrix, MaskedMSELoss(), lambda_pathway=1.0) + loss = loss_fn(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + + assert loss.isfinite() + assert not torch.isnan(loss) + + def test_lambda_scaling(self, pathway_tensors): + """Doubling lambda_pathway should increase the aux contribution accordingly.""" + gene_preds, targets, pw_preds, pw_matrix, mask = pathway_tensors + base = MaskedMSELoss() + + aux1 = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=0.5) + aux2 = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=1.0) + + loss1 = aux1(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + loss2 = aux2(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + gene_loss = base(gene_preds, targets, mask=mask) + + term1 = loss1 - gene_loss + term2 = loss2 - gene_loss + + # term2 should be approx 2 * term1 + assert term2.item() == pytest.approx(2 * term1.item(), rel=1e-4) + + def test_hallmark_integration(self): + """Test with a real (though small) MSigDB Hallmark matrix.""" + from spatial_transcript_former.data.pathways import get_pathway_init + from spatial_transcript_former.training.losses import ( + AuxiliaryPathwayLoss, + MaskedMSELoss, + ) + + # Subset of genes known to be in MSigDB Hallmarks + gene_list = ["TP53", "MYC", "VEGFA", "VIM", "CDH1", "SNAI1", "AXIN2", "MLH1"] + B, N, G = 2, 5, len(gene_list) + + # Get real Hallmark matrix (only for our 8 genes) + matrix, names = get_pathway_init(gene_list, verbose=False) + P = len(names) + + preds = torch.randn(B, N, G) + targets = torch.randn(B, N, G).abs() + pw_preds = torch.randn(B, N, P) + + loss_fn = AuxiliaryPathwayLoss(matrix, MaskedMSELoss()) + loss = loss_fn(preds, targets, pathway_preds=pw_preds) + + assert loss.isfinite() + assert len(names) > 0 + + def test_hallmark_signal_detection(self): + """Verify that gene patterns aligned with a hallmark reduce the aux loss.""" + from spatial_transcript_former.training.losses import ( + AuxiliaryPathwayLoss, + MaskedMSELoss, + ) + + # Define 4 genes and 2 pathways + # P0: G0, G1 + # P1: G2, G3 + pw_matrix = torch.tensor( + [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]], dtype=torch.float32 + ) + gene_list = ["G0", "G1", "G2", "G3"] + B, N, G = 1, 10, 4 + P = 2 + + # Case 1: Random targets + torch.manual_seed(42) + targets = torch.randn(B, N, G).abs() + # Predictions match targets for genes + gene_preds = targets.clone() + + # Pathway preds are random + pw_preds_random = torch.randn(B, N, P) + + loss_fn = AuxiliaryPathwayLoss(pw_matrix, MaskedMSELoss(), lambda_pathway=1.0) + loss_random = loss_fn(gene_preds, targets, pathway_preds=pw_preds_random) + + # Case 2: Pathway preds perfectly match truth (which is targets @ matrix.T) + pw_truth = torch.matmul(targets, pw_matrix.T) + loss_perfect = loss_fn(gene_preds, targets, pathway_preds=pw_truth) + + # Case 3: Gene expression is specifically high for P0, and pw_preds are high for P0 + # If the spatial correlation is high, aux loss should be low. + assert loss_perfect.item() < loss_random.item() diff --git a/tests/test_model_backbones.py b/tests/test_model_backbones.py index cb4fc79..090dbb3 100644 --- a/tests/test_model_backbones.py +++ b/tests/test_model_backbones.py @@ -38,9 +38,10 @@ def test_interaction_model_backbone(): num_genes=num_genes, backbone_name="resnet50", pretrained=False ) - # Test with raw image input + # Test with raw image input (single patch => S=1) x = torch.randn(4, 3, 224, 224) - out = model(x) + rel_coords = torch.randn(4, 1, 2) + out = model(x, rel_coords=rel_coords) assert out.shape == (4, num_genes) diff --git a/tests/test_models.py b/tests/test_models.py index e16f055..39b2b1c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,54 +1,30 @@ import pytest import torch from spatial_transcript_former.models import SpatialTranscriptFormer -from spatial_transcript_former.models.interaction import MultimodalFusion def test_interaction_output_shape(mock_image_batch): """ EDUCATIONAL: This test verifies that the SpatialTranscriptFormer correctly maps a batch of histology images to a high-dimensional gene expression vector. - - Architecture: - Input (x) -> Backbone -> Interaction (Cross-Attention) -> Gene Reconstructor -> Output (y) """ num_genes = 1000 model = SpatialTranscriptFormer(num_genes=num_genes) - output = model(mock_image_batch) - - # Verify shape: (Batch Size, Number of Genes) - assert output.shape == (mock_image_batch.shape[0], num_genes) - - -def test_multimodal_fusion_quadrant_masking(): - """ - EDUCATIONAL: This test verifies Jaume Et Al. Equation 1 quadrant masking. - We verify that masking a quadrant (P2H: Pathway to Histology) effectively - blocks information flow, resulting in zero gradients for histology tokens. - - Matrix A = [ P2P P2H ] - [ H2P H2H ] - """ - dim, Np, Nh, B = 64, 5, 10, 2 - - # Mode: Mask P2H (Pathways cannot see Histology) - fusion = MultimodalFusion(dim, n_heads=4, n_layers=1, masked_quadrants=["P2H"]) - - p_tokens = torch.randn(B, Np, dim, requires_grad=True) - h_tokens = torch.randn(B, Nh, dim, requires_grad=True) + # Must provide rel_coords since use_spatial_pe defaults to True + B = mock_image_batch.shape[0] + if mock_image_batch.dim() == 5: + S = mock_image_batch.shape[1] + elif mock_image_batch.dim() == 4: + S = 1 + else: + S = mock_image_batch.shape[1] - # Forward pass: Results in Pathway tokens (Bottleneck) - out = fusion(p_tokens, h_tokens) + rel_coords = torch.randn(B, S, 2) + output = model(mock_image_batch, rel_coords=rel_coords) - # Backpropagate signal from output - out.sum().backward() - - # ASSERTION: If P2H is masked, histology tokens (H) should receive ZERO signal from the output - assert h_tokens.grad is not None - assert ( - h_tokens.grad.abs().sum() < 1e-6 - ), "Histology tokens should have no gradient if P2H quadrant is masked" + # Verify shape: (Batch Size, Number of Genes) + assert output.shape == (B, num_genes) def test_sparsity_regularization_loss(): @@ -66,21 +42,3 @@ def test_sparsity_regularization_loss(): # Expect a positive scalar (L1 norm of reconstruction weights) assert sparsity_loss > 0 assert sparsity_loss.dim() == 0 - - -def test_interaction_with_masking(mock_image_batch): - """ - EDUCATIONAL: Verifies that the unified interaction model works with - different quadrant masking configurations. - """ - num_genes = 100 - - # Test with H2H masking (Default) - model = SpatialTranscriptFormer(num_genes=num_genes, masked_quadrants=["H2H"]) - out = model(mock_image_batch) - assert out.shape == (mock_image_batch.shape[0], num_genes) - - # Test with all quadrants open - model_all = SpatialTranscriptFormer(num_genes=num_genes, masked_quadrants=[]) - out_all = model_all(mock_image_batch) - assert out_all.shape == (mock_image_batch.shape[0], num_genes) diff --git a/tests/test_neighborhood.py b/tests/test_neighborhood.py index eb363ce..4198bc9 100644 --- a/tests/test_neighborhood.py +++ b/tests/test_neighborhood.py @@ -32,32 +32,5 @@ def test_neighborhood_model_forward(): print("Neighborhood forward pass test passed!") -def test_spatial_masking(): - B = 1 - S = 4 - G = 100 - # Center at (0,0), neighbors at (10,0), (100,0), (1000,0) - rel_coords = torch.tensor( - [[[0, 0], [10, 0], [100, 0], [1000, 0]]], dtype=torch.float32 - ) - - # Set mask radius to 50 -> neighbors 2 and 3 should be masked - model = SpatialTranscriptFormer(num_genes=G, mask_radius=50) - - mask = model._generate_spatial_mask(rel_coords) - # Expected mask: [False, False, True, True] (True means ignore) - - print(f"Rel Coords: {rel_coords}") - print(f"Generated Mask: {mask}") - - assert mask[0, 0] == False - assert mask[0, 1] == False - assert mask[0, 2] == True - assert mask[0, 3] == True - - print("Spatial masking logic test passed!") - - if __name__ == "__main__": test_neighborhood_model_forward() - test_spatial_masking() diff --git a/tests/test_nystrom.py b/tests/test_nystrom.py deleted file mode 100644 index 6b9c92e..0000000 --- a/tests/test_nystrom.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -import pytest -from spatial_transcript_former.models import SpatialTranscriptFormer - - -def test_nystrom_jaume_mode(mock_image_batch): - """ - EDUCATIONAL: Verifies that Nystrom Attention works in 'jaume' (early fusion) mode. - Nystrom provides linear complexity, allowing for multimodal interaction - without the quadratic memory wall of standard self-attention. - """ - num_genes = 100 - model = SpatialTranscriptFormer( - num_genes=num_genes, use_nystrom=True, num_landmarks=32 # Small for testing - ) - - output = model(mock_image_batch) - assert output.shape == (mock_image_batch.shape[0], num_genes) - - -def test_nystrom_no_quadrant_mask_support(mock_image_batch): - """ - EDUCATIONAL: Nystrom Attention (linear complexity) does not support - standard 2D quadrant masking. This test verifies the model still - executes without crashing, but we acknowledge the mask is ignored. - """ - num_genes = 100 - model = SpatialTranscriptFormer( - num_genes=num_genes, - use_nystrom=True, - masked_quadrants=["H2H"], - num_landmarks=32, - ) - - output = model(mock_image_batch) - assert output.shape == (mock_image_batch.shape[0], num_genes) - - -def test_nystrom_scalability(): - """ - EDUCATIONAL: This test verifies that Nystrom mode can handle very large - histology sequences that would typically crash a standard transformer. - """ - B, S, D = 1, 1024, 512 # Large neighborhood - num_genes = 10 - - # We'll use mock projection features instead of raw images to save memory in CI - # token_dim=512 is used in SpatialTranscriptFormer - features = torch.randn(B, S, 2048) # Backbone output dim - - model = SpatialTranscriptFormer( - num_genes=num_genes, use_nystrom=True, num_landmarks=128 - ) - - # Forward pass on a sequence of 1024 tokens - with torch.no_grad(): - output = model(features) - - assert output.shape == (B, num_genes) diff --git a/tests/test_quadrant_masking.py b/tests/test_quadrant_masking.py deleted file mode 100644 index e43bf1d..0000000 --- a/tests/test_quadrant_masking.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.nn as nn -import sys -import os - -# Add src to path -sys.path.append(os.path.abspath("src")) - -from spatial_transcript_former.models.interaction import MultimodalFusion - - -def test_quadrant_masking_interactions(): - dim = 64 - n_heads = 4 - n_layers = 1 - Np = 5 - Nh = 10 - B = 2 - - # Mode 1: Full Attention - fusion_full = MultimodalFusion(dim, n_heads, n_layers) - - # Mode 2: Mask H2H - fusion_masked = MultimodalFusion(dim, n_heads, n_layers, masked_quadrants=["H2H"]) - - p_tokens = torch.randn(B, Np, dim, requires_grad=True) - h_tokens = torch.randn(B, Nh, dim, requires_grad=True) - - # Forward full - out_full = fusion_full(p_tokens, h_tokens) - loss_full = out_full.sum() - loss_full.backward() - - h_grad_full = h_tokens.grad.abs().sum() - print(f"H tokens grad (Full): {h_grad_full.item()}") - - # Reset grad - h_tokens.grad.zero_() - p_tokens.grad.zero_() - - # Forward masked H2H - # NOTE: Masking H2H in self-attention technically means H tokens don't attend to H tokens. - # But H tokens can still attend to P tokens (H2P) unless we mask that too. - # If we mask H2H, H tokens are purely transformed by their interaction with P tokens. - out_masked = fusion_masked(p_tokens, h_tokens) - loss_masked = out_masked.sum() - loss_masked.backward() - - h_grad_masked = h_tokens.grad.abs().sum() - print(f"H tokens grad (Masked H2H): {h_grad_masked.item()}") - - # Forward masked H2H AND H2P (H tokens isolated) - h_tokens.grad.zero_() - fusion_isolated = MultimodalFusion( - dim, n_heads, n_layers, masked_quadrants=["H2H", "H2P"] - ) - out_iso = fusion_isolated(p_tokens, h_tokens) - out_iso.sum().backward() - - h_grad_iso = h_tokens.grad.abs().sum() - print(f"H tokens grad (H2H + H2P Masked): {h_grad_iso.item()}") - - # In H2H + H2P masked mode, output depends on P tokens attending to H and P. - # H tokens shouldn't receive any gradients from the output if they are isolated - # and the output only contains the transformed P tokens. - # WAIT: P tokens STILL attend to H tokens (P2H). So H tokens WILL get gradients. - - # Let's verify P2H masking - h_tokens.grad.zero_() - fusion_no_hist = MultimodalFusion(dim, n_heads, n_layers, masked_quadrants=["P2H"]) - out_no_hist = fusion_no_hist(p_tokens, h_tokens) - out_no_hist.sum().backward() - h_grad_no_hist = h_tokens.grad.abs().sum() - print(f"H tokens grad (P2H Masked): {h_grad_no_hist.item()}") - - assert ( - h_grad_no_hist < 1e-5 - ), "H tokens should not receive gradients if P2H is masked" - print("Quadrant masking verification passed!") - - -if __name__ == "__main__": - test_quadrant_masking_interactions() diff --git a/tests/test_spatial.py b/tests/test_spatial.py index 5a5e094..b97c59d 100644 --- a/tests/test_spatial.py +++ b/tests/test_spatial.py @@ -18,38 +18,3 @@ def test_neighborhood_forward_pass(mock_neighborhood_batch, mock_rel_coords): output = model(mock_neighborhood_batch, rel_coords=mock_rel_coords) assert output.shape == (mock_neighborhood_batch.shape[0], num_genes) - - -def test_distance_based_spatial_masking(): - """ - EDUCATIONAL: This test verifies that the 'mask_radius' correctly suppresses - interactions with distant patches in the neighborhood. - - We place patches at different distances and verify that the generated - attention mask ignores those beyond the radius. - """ - G = 100 - # Relative Coords: Center(0,0), Near(10,0), Far(100,0), Very Far(1000,0) - rel_coords = torch.tensor( - [ - [ - [0, 0], # Index 0 (Center) - [10, 0], # Index 1 (Near) - [100, 0], # Index 2 (Far) - [1000, 0], # Index 3 (Very Far) - ] - ], - dtype=torch.float32, - ) - - # Set radius to 50: Only Center and Near should be visible - model = SpatialTranscriptFormer(num_genes=G, mask_radius=50) - - # Internal method creates a boolean mask (True = Ignore) - mask = model._generate_spatial_mask(rel_coords) - - # Expected: [False (Seen), False (Seen), True (Masked), True (Masked)] - assert mask[0, 0] == False - assert mask[0, 1] == False - assert mask[0, 2] == True - assert mask[0, 3] == True diff --git a/tests/test_spatial_alignment.py b/tests/test_spatial_alignment.py index f5ac0b0..611ec2e 100644 --- a/tests/test_spatial_alignment.py +++ b/tests/test_spatial_alignment.py @@ -5,28 +5,14 @@ def test_spatial_mixing_with_large_coordinates(): """ - Verifies that the model correctly normalizes large pixel coordinates (e.g. 256px steps) - so that local mixing can occur between adjacent patches. + Verifies that the model correctly handles large pixel coordinates (e.g. 256px steps) + and that gradient flows between patches via the pathway bottleneck. """ - # ResNet50 output dim is 2048, so we must use token_dim=2048 or project it. - # The model has an image_proj layer: Linear(image_feature_dim, token_dim). - # Wait, the error was 2x64 and 2048x64? - # Ah, the error "mat1 and mat2 shapes cannot be multiplied (2x64 and 2048x64)" - # suggests the input x has dim 64, but the first layer (backbone) expects something else? - # Actually, the model forward() expects images (B, C, H, W) OR features (B, N, D). - # If passing features, D must match image_feature_dim (2048 for ResNet). - - # Let's verify model source: - # if x.dim() == 3: features = x - # memory = self.image_proj(features) - # image_proj is Linear(2048, token_dim). - - # So input x mus be (B, N, 2048). dim = 2048 token_dim = 64 model = SpatialTranscriptFormer( - num_genes=10, token_dim=token_dim, n_layers=1, use_spatial_pe=True + num_genes=10, token_dim=token_dim, n_layers=2, use_spatial_pe=True ) # Create two patches that are physically adjacent (256px apart) but logically neighbors @@ -40,8 +26,8 @@ def test_spatial_mixing_with_large_coordinates(): output = model(x, rel_coords=coords) # Check gradient flow from Patch 1 to Patch 0's output - # If mixing works, output[0] should be influenced by input[1] via the LocalPatchMixer - # We sum output[0] and check grad w.r.t x[0, 1] + # With the pathway bottleneck, both patches attend to pathway tokens, so + # indirect gradient flow exists through the shared pathway embeddings. loss = output[0, 0].sum() loss.backward() @@ -49,8 +35,7 @@ def test_spatial_mixing_with_large_coordinates(): assert grad_neighbor > 0.0, ( f"Gradient from neighbor is zero ({grad_neighbor}). " - "The model failed to mix features between patches separated by 256px. " - "This indicates the LocalPatchMixer is treating them as distant neighbors." + "The model failed to propagate gradients between patches via pathway tokens." ) diff --git a/tests/test_spatial_interaction.py b/tests/test_spatial_interaction.py index 308e861..b887043 100644 --- a/tests/test_spatial_interaction.py +++ b/tests/test_spatial_interaction.py @@ -3,58 +3,163 @@ from unittest.mock import MagicMock from spatial_transcript_former.models.interaction import ( SpatialTranscriptFormer, - LocalPatchMixer, - GraphPatchMixer, + LearnedSpatialEncoder, + VALID_INTERACTIONS, ) from spatial_transcript_former.training.engine import train_one_epoch, validate -def test_normalize_coords_finds_correct_step(): - """ - Test that _normalize_coords correctly infers the grid spacing (e.g., 224) - even when absolute coordinates are very large, which breaks the current median - absolute value heuristic. - """ - model = SpatialTranscriptFormer(num_genes=10, use_nystrom=False) +def test_n_layers_enforcement(): + """n_layers < 2 with h2h blocked should raise ValueError.""" + with pytest.raises(ValueError, match="n_layers must be >= 2"): + SpatialTranscriptFormer( + num_genes=50, n_layers=1, interactions=["p2p", "p2h", "h2p"] + ) + - # Simulate a 2x2 grid of patches with patch size 224 in absolute pixel coords - coords = torch.tensor( - [[10000.0, 20000.0], [10224.0, 20000.0], [10000.0, 20224.0], [10224.0, 20224.0]] - ).unsqueeze( - 0 - ) # Output shape: (1, 4, 2) +def test_n_layers_ok_with_full_interactions(): + """n_layers=1 is allowed when h2h is enabled (full attention).""" + model = SpatialTranscriptFormer( + num_genes=50, + token_dim=64, + n_heads=4, + n_layers=1, + interactions=["p2p", "p2h", "h2p", "h2h"], + ) + assert model is not None - normalized = model._normalize_coords(coords) - # LocalPatchMixer subtracts the minimum coordinate to create a zero-indexed grid. - min_c = normalized.min(dim=1, keepdim=True)[0] - grid_coords = normalized - min_c +def test_invalid_interaction_key(): + """Unknown interaction keys should raise ValueError.""" + with pytest.raises(ValueError, match="Unknown interaction keys"): + SpatialTranscriptFormer(num_genes=50, interactions=["p2p", "x2y"]) - expected_grid = torch.tensor( - [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]] - ).unsqueeze(0) - # The current heuristic fails this because it divides by ~15000 (median) - assert torch.allclose( - grid_coords, expected_grid - ), f"Coordinate normalization failed, returned:\n{grid_coords}" +@pytest.mark.parametrize( + "interactions", + [ + ["p2p", "p2h", "h2p", "h2h"], # full + ["p2p", "p2h", "h2p"], # bottleneck + ["p2p", "p2h"], # pathway-only + ], +) +def test_interaction_combinations(interactions): + """Various interaction combos should produce correct output shapes.""" + model = SpatialTranscriptFormer( + num_genes=50, + token_dim=64, + n_heads=4, + n_layers=2, + interactions=interactions, + ) + B, N, D = 2, 5, 2048 + features = torch.randn(B, N, D) + coords = torch.randn(B, N, 2) + out = model(features, rel_coords=coords) + assert out.shape == (B, 50) + + +def test_full_interactions_returns_no_mask(): + """When all interactions are enabled, _build_interaction_mask returns None.""" + model = SpatialTranscriptFormer( + num_genes=50, + token_dim=64, + n_heads=4, + n_layers=2, + interactions=["p2p", "p2h", "h2p", "h2h"], + ) + mask = model._build_interaction_mask(p=10, s=20, device=torch.device("cpu")) + assert mask is None + +def test_missing_coords_raises(): + """use_spatial_pe=True without coords should raise ValueError.""" + model = SpatialTranscriptFormer(num_genes=50, token_dim=64, n_heads=4, n_layers=2) + features = torch.randn(2, 10, 2048) + with pytest.raises(ValueError, match="rel_coords was not provided"): + model(features) -def test_engine_passes_coords_to_dense_forward(): + +def test_spatial_transcript_former_dense_forward(): + """Instantiate the model and ensure forward() with return_dense=True executes properly.""" + model = SpatialTranscriptFormer( + num_genes=50, + token_dim=64, + n_heads=4, + n_layers=2, + ) + + B, N, D = 2, 10, 2048 + features = torch.randn(B, N, D) + coords = torch.randn(B, N, 2) + + # 1. Standard Forward Pass + out = model(features, rel_coords=coords) + assert out.shape == ( + B, + 50, + ), f"Expected global output shape {(B, 50)}, got {out.shape}" + + # Reset grads + model.zero_grad() + + # 2. Dense Forward Pass + out_dense = model.forward(features, rel_coords=coords, return_dense=True) + assert out_dense.shape == ( + B, + N, + 50, + ), f"Expected dense output shape {(B, N, 50)}, got {out_dense.shape}" + + out_dense.sum().backward() + assert ( + model.fusion_engine.layers[0].self_attn.in_proj_weight.grad is not None + ), "Gradients did not flow through the transformer in dense pass." + + +def test_unified_pathway_scoring(): + """Both global and dense modes should produce pathway_scores from dot-products.""" + model = SpatialTranscriptFormer( + num_genes=50, + token_dim=64, + n_heads=4, + n_layers=2, + num_pathways=10, + ) + + B, N, D = 2, 5, 2048 + features = torch.randn(B, N, D) + coords = torch.randn(B, N, 2) + + # Global mode returns (gene_expression, pathway_scores) + gene_expr, pw_scores = model(features, rel_coords=coords, return_pathways=True) + assert gene_expr.shape == (B, 50) + assert pw_scores.shape == (B, 10) # (B, P) from pooled dot-product + + # Dense mode returns (gene_expression, pathway_scores) + gene_expr_d, pw_scores_d = model( + features, rel_coords=coords, return_pathways=True, return_dense=True + ) + assert gene_expr_d.shape == (B, N, 50) + assert pw_scores_d.shape == (B, N, 10) # (B, S, P) from per-patch dot-product + + +def test_engine_passes_coords_to_forward(): """ - Verify that the training engine passes the generated spatial coordinates to - the model in whole_slide mode. + Verify that the training engine passes spatial coordinates to + model.forward() in whole_slide mode. """ model = SpatialTranscriptFormer(num_genes=10) - # Mock to track if coords are passed - model.forward_dense = MagicMock(return_value=torch.randn(2, 5, 10)) + # Mock forward to track calls + original_forward = model.forward + model.forward = MagicMock(return_value=torch.randn(2, 5, 10)) model.get_sparsity_loss = MagicMock(return_value=torch.tensor(0.0)) fake_coords = torch.randn(2, 5, 2) fake_mask = torch.zeros(2, 5).bool() - # Datloader yielding (feats, genes, coords, mask) + # Dataloader yielding (feats, genes, coords, mask) loader = [(torch.randn(2, 5, 512), torch.randn(2, 5, 10), fake_coords, fake_mask)] optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) @@ -64,12 +169,12 @@ def dummy_criterion(p, t, mask=None): train_one_epoch(model, loader, dummy_criterion, optimizer, "cpu", whole_slide=True) - model.forward_dense.assert_called_once() - kwargs = model.forward_dense.call_args.kwargs + model.forward.assert_called_once() + kwargs = model.forward.call_args.kwargs - assert "coords" in kwargs, "Engine did not pass 'coords' kwargs to forward_dense!" + assert "rel_coords" in kwargs, "Engine did not pass 'rel_coords' kwargs to forward!" assert torch.allclose( - kwargs["coords"], fake_coords + kwargs["rel_coords"], fake_coords ), "Engine passed wrong coordinate tensor!" @@ -78,7 +183,7 @@ def test_engine_validate_passes_coords(): Verify validation loop passes coords. """ model = SpatialTranscriptFormer(num_genes=10) - model.forward_dense = MagicMock(return_value=torch.randn(2, 5, 10)) + model.forward = MagicMock(return_value=torch.randn(2, 5, 10)) fake_coords = torch.randn(2, 5, 2) loader = [ @@ -95,79 +200,12 @@ def dummy_criterion(p, t, mask=None): validate(model, loader, dummy_criterion, "cpu", whole_slide=True) - model.forward_dense.assert_called_once() - kwargs = model.forward_dense.call_args.kwargs + model.forward.assert_called_once() + kwargs = model.forward.call_args.kwargs assert ( - "coords" in kwargs - ), "Validate engine did not pass 'coords' kwargs to forward_dense!" + "rel_coords" in kwargs + ), "Validate engine did not pass 'rel_coords' kwargs to forward!" assert torch.allclose( - kwargs["coords"], fake_coords + kwargs["rel_coords"], fake_coords ), "Validate engine passed wrong coordinate tensor!" - - -def test_graph_patch_mixer(): - """Verify that the GraphPatchMixer correctly performs message passing over a k-NN graph.""" - B, N, D = 2, 10, 32 - k = 3 - mixer = GraphPatchMixer(dim=D, k=k, heads=4) - - x = torch.randn(B, N, D) - coords = torch.randn(B, N, 2) - - # 1. Forward Pass - out = mixer(x, coords) - - # 2. Shape Verification - assert out.shape == (B, N, D), f"Expected shape {(B, N, D)}, got {out.shape}" - - # 3. Gradient Flow Verification - out.sum().backward() - assert ( - mixer.to_qkv.weight.grad is not None - ), "Gradients did not flow through the GAT layer." - - -def test_spatial_transcript_former_with_gnn_refiner(): - """Instantiate the model with GNN refiner and ensure forward() and forward_dense() execute properly.""" - model = SpatialTranscriptFormer( - num_genes=50, - token_dim=64, - n_heads=4, - n_layers=1, - early_mixer=None, - late_refiner="gnn", - ) - - B, N, D = 2, 10, 2048 - features = torch.randn(B, N, D) - coords = torch.randn(B, N, 2) - - # Check that early_mixer is None and late_refiner is initialized - assert model.early_mixer is None, "early_mixer should be None" - assert ( - hasattr(model, "late_refiner") and model.late_refiner is not None - ), "late_refiner not initialized" - - # 1. Standard Forward Pass - out = model(features, rel_coords=coords) - assert out.shape == ( - B, - 50, - ), f"Expected dense output shape {(B, 50)}, got {out.shape}" - - # Reset grads - model.zero_grad() - - # 2. Dense Forward Pass (This is the one that uses the Graph Refiner explicitly) - out_dense = model.forward_dense(features, coords=coords) - assert out_dense.shape == ( - B, - N, - 50, - ), f"Expected dense output shape {(B, N, 50)}, got {out_dense.shape}" - - out_dense.sum().backward() - assert ( - model.late_refiner.to_qkv.weight.grad is not None - ), "Gradients did not flow through the late_refiner in dense pass." From 724d7b864495e7e6d0117724f3c4b0ae0d24760b Mon Sep 17 00:00:00 2001 From: BenjaminIsaac0111 <12176376+BenjaminIsaac0111@users.noreply.github.com> Date: Wed, 25 Feb 2026 09:02:41 +0000 Subject: [PATCH 2/4] docs: implement branch protections and code ownership - Update CONTRIBUTING.md to formally document main branch protection requirements, including mandatory PR reviews, status checks, and linear history. - Create .github/CODEOWNERS to automatically assign @BenjaminIsaac0111 as a reviewer for all project modules and documentation. - Align contribution guidelines with the new Repository Ruleset workflow. --- .github/CODEOWNERS | 17 +++++++++++++++++ CONTRIBUTING.md | 9 +++++++++ 2 files changed, 26 insertions(+) create mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..c25fab4 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,17 @@ +# Global owners (catch-all for all files) +* @BenjaminIsaac0111 + +# Core Model and Training Logic +src/spatial_transcript_former/models/ @BenjaminIsaac0111 +src/spatial_transcript_former/training/ @BenjaminIsaac0111 + +# Data Management and Scripts +src/spatial_transcript_former/data/ @BenjaminIsaac0111 +scripts/ @BenjaminIsaac0111 + +# Documentation +docs/ @BenjaminIsaac0111 +*.md @BenjaminIsaac0111 + +# GitHub Actions and Infrastructure +.github/ @BenjaminIsaac0111 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bd90cd8..bcebc17 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -55,6 +55,15 @@ bash test.sh # Linux 3. **Documentation**: Update relevant files in `docs/` and the `README.md` if your change affects usage. 4. **Verification**: Ensure all CI checks (GitHub Actions) pass. +### Branch Protections + +To maintain code quality and stability, the following protections are enforced on the `main` branch: + +- **Require Pull Request Reviews**: All merges to `main` require at least one approval from a project maintainer. +- **Required Status Checks**: The `CI` workflow must pass successfully before a PR can be merged. This includes formatting checks (`black`) and the full test suite (`pytest`). +- **No Direct Pushes**: Pushing directly to `main` is disabled. All changes must go through the Pull Request process. +- **Linear History**: We prefer **Squash and Merge** to keep the `main` branch history clean and concise. + ## Contact For questions regarding commercial licensing or complex architectural changes, please contact the author directly. From 8cf2cc291cf331ef3987235cd4b9e81ee05d371b Mon Sep 17 00:00:00 2001 From: BenjaminIsaac0111 <12176376+BenjaminIsaac0111@users.noreply.github.com> Date: Wed, 25 Feb 2026 09:41:54 +0000 Subject: [PATCH 3/4] Fixed TypeError: Resolved a TypeError in SpatialTranscriptFormer by correctly placing enable_nested_tensor=False in the TransformerEncoder constructor. Pytest Configuration: Configured pyproject.toml to suppress common non-critical warnings (Deprecation, Matplotlib, etc.). Demonstration Test: Added tests/test_warnings.py to illustrate how to handle warnings using pytest.warns. --- .../models/interaction.py | 5 ++- tests/test_neighborhood.py | 36 --------------- tests/test_spatial.py | 20 --------- tests/test_warnings.py | 44 +++++++++++++++++++ 4 files changed, 48 insertions(+), 57 deletions(-) delete mode 100644 tests/test_neighborhood.py delete mode 100644 tests/test_spatial.py create mode 100644 tests/test_warnings.py diff --git a/src/spatial_transcript_former/models/interaction.py b/src/spatial_transcript_former/models/interaction.py index 1832f7d..695cc77 100644 --- a/src/spatial_transcript_former/models/interaction.py +++ b/src/spatial_transcript_former/models/interaction.py @@ -152,7 +152,10 @@ def __init__( batch_first=True, norm_first=True, ) - self.fusion_engine = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) + + self.fusion_engine = nn.TransformerEncoder( + encoder_layer, num_layers=n_layers, enable_nested_tensor=False + ) # Learnable temperature for cosine similarity scoring # Initialized to log(1/0.07) ≈ 2.66 following CLIP convention diff --git a/tests/test_neighborhood.py b/tests/test_neighborhood.py deleted file mode 100644 index 4198bc9..0000000 --- a/tests/test_neighborhood.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -import numpy as np -import sys -import os - -# Add src to path -sys.path.append(os.path.abspath("src")) - -from spatial_transcript_former.models import SpatialTranscriptFormer -from spatial_transcript_former.data.dataset import HEST_Dataset - - -def test_neighborhood_model_forward(): - B = 2 - S = 9 # 1 center + 8 neighbors - C, H, W = 3, 224, 224 - G = 1000 - - model = SpatialTranscriptFormer(num_genes=G) - - # Input sequence - x = torch.randn(B, S, C, H, W) - rel_coords = torch.randn(B, S, 2) - - # Forward pass - out = model(x, rel_coords=rel_coords) - - print(f"Neighborhood Input shape: {x.shape}") - print(f"Output shape: {out.shape}") - - assert out.shape == (B, G), f"Expected (B, G), got {out.shape}" - print("Neighborhood forward pass test passed!") - - -if __name__ == "__main__": - test_neighborhood_model_forward() diff --git a/tests/test_spatial.py b/tests/test_spatial.py deleted file mode 100644 index b97c59d..0000000 --- a/tests/test_spatial.py +++ /dev/null @@ -1,20 +0,0 @@ -import pytest -import torch -from spatial_transcript_former.models import SpatialTranscriptFormer - - -def test_neighborhood_forward_pass(mock_neighborhood_batch, mock_rel_coords): - """ - EDUCATIONAL: This test verifies that the model can process a sequence of - histology patches (Neighborhood Mode). - - Input shape: (Batch, Sequence Length, Channels, Height, Width) - Output shape: (Batch, Number of Genes) - """ - num_genes = 1000 - model = SpatialTranscriptFormer(num_genes=num_genes) - - # Forward pass with neighborhood sequence and relative coordinates - output = model(mock_neighborhood_batch, rel_coords=mock_rel_coords) - - assert output.shape == (mock_neighborhood_batch.shape[0], num_genes) diff --git a/tests/test_warnings.py b/tests/test_warnings.py new file mode 100644 index 0000000..f94ead4 --- /dev/null +++ b/tests/test_warnings.py @@ -0,0 +1,44 @@ +import pytest +import warnings + + +def function_that_warns(): + warnings.warn("This is a deprecated feature", DeprecationWarning) + return True + + +def function_that_warns_user(): + warnings.warn("FigureCanvasAgg is non-interactive", UserWarning) + return True + + +def test_demonstrate_warning_assertion(): + """ + Demonstrates how to assert that a specific warning is raised. + This is useful for ensuring your code warns users correctly. + """ + with pytest.warns(DeprecationWarning, match="deprecated feature"): + result = function_that_warns() + assert result is True + + +def test_global_filter_demonstration(): + """ + This test will pass without showing warnings in the output + because we added filters to pyproject.toml. + + Specifically 'FigureCanvasAgg is non-interactive' is filtered. + """ + result = function_that_warns_user() + assert result is True + + +def test_how_to_catch_and_ignore_locally(): + """ + If you want to ignore a warning locally in a specific test + without adding it to the global pyproject.toml. + """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + result = function_that_warns() + assert result is True From 39530e8ada6317f86e620760a6363ef25b245f0f Mon Sep 17 00:00:00 2001 From: BenjaminIsaac0111 <12176376+BenjaminIsaac0111@users.noreply.github.com> Date: Wed, 25 Feb 2026 09:44:30 +0000 Subject: [PATCH 4/4] Forgot this one, oops hehe --- pyproject.toml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 62f7a4f..e4064ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,3 +44,13 @@ stf-build-vocab = "spatial_transcript_former.data.build_vocab:main" [tool.setuptools.packages.find] where = ["src"] include = ["spatial_transcript_former*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["src"] +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore:FigureCanvasAgg is non-interactive:UserWarning", + "ignore:os.fork:RuntimeWarning", + "ignore:torch.utils._pytree._register_pytree_node is deprecated:UserWarning", +]