diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bcebc17..d3f5ccd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -35,18 +35,23 @@ We use `black` for formatting and `flake8` for linting. Please ensure your code ```bash black . -flake8 src/ ``` -### 3. Testing +### 3. AI-Assisted Development -All new features must include unit tests in the `tests/` directory. We use `pytest` for our test suite. +We welcome contributions developed with the assistance of AI tools (e.g., Copilot, ChatGPT, Claude, or agentic frameworks). However, to ensure the long-term maintainability and integrity of the project: -```bash -# Run all tests -.\test.ps1 # Windows -bash test.sh # Linux -``` +- **Ownership**: You are ultimately responsible for the code you submit. Do not commit code you do not fully understand. +- **Explainability**: During the review process, you must be able to explain the logic, design decisions, and any subtle side effects of the AI-suggested changes. +- **Verification**: AI-generated code must strictly follow our coding standards, naming conventions, and architectural patterns. It must be accompanied by robust tests (see our [Testing Guide](docs/TESTING.md)). + +### 4. Testing & Quality Assurance + +All new features must be accompanied by relevant tests in the `tests/` directory, written with `pytest`. + +We highly encourage rigorous testing approaches such as **Mutation Testing** (via `cosmic-ray`) for critical model components to prevent surviving mutants. + +For full details on our testing requirements, how to run the test suites locally, and our guidelines on mutation testing, please read the [Testing Guide](docs/TESTING.md). 
## Pull Request Process diff --git a/Download-BowelCancer.ps1 b/Download-BowelCancer.ps1 deleted file mode 100644 index 3422441..0000000 --- a/Download-BowelCancer.ps1 +++ /dev/null @@ -1,105 +0,0 @@ -# PowerShell Script to run and track HEST Bowel Cancer download progress - -$TargetDir = "A:\hest_data" -$TargetSizeGB = 95 -$TargetSizeFactor = 1GB * $TargetSizeGB -$PythonExe = "C:\Users\wispy\miniconda3\envs\SpatialTranscriptFormer\python.exe" - -Write-Host "--- HEST Bowel Cancer Download Tracker ---" -ForegroundColor Cyan -Write-Host "Target Directory: $TargetDir" -Write-Host "Estimated Total Size: $TargetSizeGB GB" -Write-Host "" - -# Check if Python script is running using CIM for CommandLine access -function Get-DownloadProcess { - return Get-CimInstance Win32_Process -Filter "Name = 'python.exe' and CommandLine like '%download_bowel_cancer.py%'" -} - -$Process = Get-DownloadProcess - -if ($null -eq $Process) { - Write-Host "Download script is not running. Starting it now..." -ForegroundColor Yellow - if (-not (Test-Path $PythonExe)) { - Write-Host "Error: Python executable not found at $PythonExe" -ForegroundColor Red - exit - } - # Start the process in the background - Start-Process $PythonExe -ArgumentList "scripts/download_bowel_cancer.py" -NoNewWindow - Start-Sleep -Seconds 5 # Give it a bit more time to initialize - $Process = Get-DownloadProcess - - if ($null -eq $Process) { - Write-Host "Error: Failed to start the download script or it crashed immediately." -ForegroundColor Red - Write-Host "Please check for error messages above." -ForegroundColor Yellow - exit - } - Write-Host "Started download script (PID: $($Process.ProcessId))." -ForegroundColor Green -} -else { - Write-Host "Download script is already running (PID: $($Process.ProcessId))." 
-ForegroundColor Green -} - -$StartTime = Get-Date -$DownloadPID = $Process.ProcessId - -while ($true) { - if (Test-Path $TargetDir) { - # Using .NET for much faster file enumeration than Get-ChildItem -Recurse - [long]$CurrentSize = 0 - try { - $files = [System.IO.Directory]::EnumerateFiles($TargetDir, "*", [System.IO.SearchOption]::AllDirectories) - foreach ($file in $files) { - $CurrentSize += (New-Object System.IO.FileInfo($file)).Length - } - } - catch { - # Directory might be busy or locked - } - - $Percent = [Math]::Min(100, ($CurrentSize / $TargetSizeFactor) * 100) - $CurrentGB = [Math]::Round($CurrentSize / 1GB, 2) - - $TimeElapsed = (Get-Date) - $StartTime - if ($CurrentSize -gt 0 -and $TimeElapsed.TotalSeconds -gt 0) { - $TotalEstimatedTime = New-TimeSpan -Seconds ($TimeElapsed.TotalSeconds * ($TargetSizeFactor / $CurrentSize)) - $RemainingTime = $TotalEstimatedTime - $TimeElapsed - $Speed = [Math]::Round(($CurrentSize / 1MB) / $TimeElapsed.TotalSeconds, 2) - } - else { - $RemainingTime = New-TimeSpan -Seconds 0 - $Speed = 0 - } - - $ProgressBar = "[" + ("#" * [Math]::Floor($Percent / 5)) + ("." * (20 - [Math]::Floor($Percent / 5))) + "]" - - # Format ETA as HH:MM:SS - $etaStr = if ($RemainingTime.TotalSeconds -gt 0) { - "$([int]$RemainingTime.TotalHours):$($RemainingTime.Minutes.ToString('00')):$($RemainingTime.Seconds.ToString('00'))" - } - else { "Calculating..." } - - $Status = "$ProgressBar $($Percent.ToString("F2"))% | $CurrentGB / $TargetSizeGB GB | Speed: $Speed MB/s | ETA: $etaStr" - Write-Progress -Activity "Downloading HEST Bowel Cancer Subset" -Status $Status -PercentComplete $Percent - - # Also print to console - Write-Host "`r$Status" -NoNewline - } - else { - Write-Host "`rWaiting for target directory to be created..." 
-NoNewline - } - - # Check if process is still running - $CurrentProc = Get-CimInstance Win32_Process -Filter "ProcessId = $DownloadPID" - if ($null -eq $CurrentProc -and $Percent -lt 99) { - Write-Host "`nWARNING: Download process (PID $DownloadPID) seems to have stopped unexpectedly!" -ForegroundColor Red - Write-Host "Check the terminal output above for any Python errors." -ForegroundColor Yellow - break - } - - if ($Percent -ge 100) { - Write-Host "`nDownload complete!" -ForegroundColor Green - break - } - - Start-Sleep -Seconds 5 -} diff --git a/LICENSE b/LICENSE index 804048a..36665f8 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -PROPRIETARY SOURCE CODE LICENSE (NON-COMMERCIAL + NEGOTIATED COMMERCIAL) +# PROPRIETARY SOURCE CODE LICENSE (NON-COMMERCIAL + NEGOTIATED COMMERCIAL) Copyright (c) 2026 Benjamin Isaac Wilson. All rights reserved. diff --git a/README.md b/README.md index e57d9d2..b9dbd8e 100644 --- a/README.md +++ b/README.md @@ -36,40 +36,23 @@ This project requires [Conda](https://docs.conda.io/en/latest/). ## Usage -**Before running any commands**, you must activate the conda environment: +### Dataset Access -```bash -conda activate SpatialTranscriptFormer -``` - -### Download HEST Data - -> [!CAUTION] -> **Authentication Required**: The HEST dataset is gated. You must accept the terms of use at [MahmoodLab/hest](https://huggingface.co/datasets/MahmoodLab/hest) and authenticate with your Hugging Face account to download the data. - -Please provide your token using ONE of the following methods before running the download tool: - -1. **Persistent Login**: Run `huggingface-cli login` and paste your access token when prompted. -2. **Environment Variable**: Set the `HF_TOKEN` environment variable in your active terminal session. - -Once authenticated, download specific subsets using filters or the entire dataset: +The model uses the **HEST1k** dataset. You can download specific subsets (by organ, technology, etc.) 
or the entire dataset using the `stf-download` utility: ```bash -# Option 1: Download the ENTIRE HEST dataset (requires confirmation) -stf-download --local_dir hest_data +# List available filtering options +stf-download --list-options -# Option 2: Download a specific subset (e.g., Bowel Cancer) -stf-download --organ Bowel --disease Cancer --local_dir hest_data +# Download a specific subset (e.g., Breast Cancer samples from Visium) +stf-download --organ Breast --disease Cancer --tech Visium --local_dir hest_data -# Option 3: Filter by technology (e.g., Visium) -stf-download --tech Visium --local_dir hest_data +# Download all human samples +stf-download --species "Homo sapiens" --local_dir hest_data ``` -To see all available organs in the metadata: - -```bash -stf-download --list_organs -``` +> [!NOTE] +> The HEST dataset is gated on Hugging Face. Ensure you have accepted the terms at [MahmoodLab/hest](https://huggingface.co/datasets/MahmoodLab/hest) and are logged in via `huggingface-cli login`. ### Train Models @@ -122,6 +105,13 @@ Visualization plots will be saved to the `./results` directory. .\test.ps1 ``` +## Future Directions & Clinical Collaborations + +A major future direction for **SpatialTranscriptFormer** is to integrate this architecture into an **end-to-end pipeline for patient risk assessment** and prognosis tracking. By leveraging the model's predicted expression and pathway activations, we aim to build a downstream risk prediction module that allows users to directly evaluate how spatially-resolved expression relates to patient survival. + +> [!NOTE] +> **Call for Collaborators:** Rigorous risk assessment models require vast datasets of clinical metadata and survival outcomes, which we currently lack access to. We are open to investigating *any* disease of interest! If you have access to large clinical cohorts and are interested in exploring how spatial pathway activation correlates with patient prognosis, we would love to partner with you. 
+ ## Contributing We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details on our coding standards and the process for submitting pull requests. Note that this project is under a proprietary license; contributions involve an assignment of rights for non-academic use. diff --git a/docs/IP_STATEMENT.md b/docs/IP_STATEMENT.md index d007952..61c03e3 100644 --- a/docs/IP_STATEMENT.md +++ b/docs/IP_STATEMENT.md @@ -14,6 +14,18 @@ The primary innovation is the **multimodal bottleneck transformer** designed for - **Quadrant-Based Interaction Masking**: The logic used to zero out specific attention quadrants (e.g., $A_{H \to H}$) to optimize memory while maintaining multimodal context. - **Biologically-Informed Reconstruction Bottleneck**: The specific matrix decomposition approach where gene expression is reconstructed from a linear combination of pathway activations. +### Proposed Auxiliary Pathway Loss + +To prevent bottleneck collapse and provide a direct gradient signal to the pathway tokens, we use the `AuxiliaryPathwayLoss`. This loss compares the model's internal pathway scores against "ground truth" pathway activations computed from the gene expression targets via MSigDB membership. + +The total objective becomes: +$$\mathcal{L} = \mathcal{L}_{gene} + \lambda_{aux} (1 - \text{PCC}(\text{pathway\_scores}, \text{target\_pathways}))$$ + +The `--log-transform` flag applies `log1p` to targets, mitigating the heavy-tailed gene expression distribution where housekeeping genes dominate MSE. + +The full training objective with pathway sparsity regularisation: +$$\mathcal{L} = \mathcal{L}_{task} + \lambda \|W_{recon}\|_1$$ + ## 2. Spatial Context Methodologies - **Euclidean-Gated Attention**: The implementation of spatial distance-based masking ($M_{spatial}$) to constrain model focus to local morphological regions. 
diff --git a/docs/MODELS.md b/docs/MODELS.md index 0b9dc54..3029aa0 100644 --- a/docs/MODELS.md +++ b/docs/MODELS.md @@ -122,7 +122,7 @@ Together, these ensure the model learns *spatially-varying* pathway activation m #### Frozen Backbone (Feature Extraction) -Pre-computed features from a pathology foundation model. The backbone is never fine-tuned. +Pre-computed features from a pathology foundation model. (The backbone is never fine-tuned, though this might change!) | Backbone | Feature Dim | Source | | :--- | :--- | :--- | @@ -214,7 +214,7 @@ The Zero-Inflated Negative Binomial (ZINB) loss is designed for raw, highly disp The model outputs these parameters, and the loss computes the negative log-likelihood of the ground truth counts given this distribution. -### Auxiliary Pathway Loss +### Proposed Auxiliary Pathway Loss To prevent bottleneck collapse and provide a direct gradient signal to the pathway tokens, we use the `AuxiliaryPathwayLoss`. This loss compares the model's internal pathway scores against "ground truth" pathway activations computed from the gene expression targets via MSigDB membership. diff --git a/docs/PATHWAY_MAPPING.md b/docs/PATHWAY_MAPPING.md index 658899a..45ae2dc 100644 --- a/docs/PATHWAY_MAPPING.md +++ b/docs/PATHWAY_MAPPING.md @@ -21,49 +21,55 @@ After the model makes predictions (N spots x G genes), we run a statistical test - **Tool**: `gseapy` or a custom mapping script. - **Use Case**: Generating a "Pathway Activation Map" from a trained model's output. -### B. Pathway Bottleneck (Model Architecture) +### B. Interaction Model via Multi-Task Learning (MTL) -The **SpatialTranscriptFormer** replaces the standard linear output head with a two-step projection that can be configured in two modes: +The **SpatialTranscriptFormer** interaction model inherently represents pathway activations as part of its attention mechanism and output process. 
Rather than a simple linear bottleneck, it utilizes learnable pathway tokens and Multi-Task Learning (MTL). -#### 1. Informed Projection (Prior Knowledge) +#### 1. Informed Supervision via Auxiliary Loss -In this mode, the **Gene Reconstruction Matrix** $\mathbf{W}_{recon}$ is guided by established biological databases (MSigDB, KEGG). +In this mode, the network receives direct supervision on its pathway tokens, guided by established biological databases (e.g., MSigDB): -- **Implementation**: $\mathbf{W}_{recon}$ is initialized as a binary mask $M \in \{0, 1\}^{G \times P}$ where $M_{gk} = 1$ if gene $g$ belongs to pathway $k$. -- **Benefit**: Predictions are guaranteed to be linear combinations of known biological processes, making them instantly interpretable by clinicians. +- **Architecture Flow**: + 1. **Interaction**: Learnable pathway tokens $P$ interact with Histology patch features $H$ via self-attention (e.g., $p2h$, $h2p$). + 2. **Activation**: Pathway scores $S \in \mathbb{R}^P$ are computed using a learnable temperature-scaled cosine similarity between the pathway tokens and image patch tokens. + 3. **Gene Reconstruction**: $\hat{y} = S \cdot \mathbf{W}_{recon} + b$, where $\mathbf{W}_{recon}$ is initialized using the binary pathway membership matrix $M$. +- **MTL Auxiliary Loss**: To prevent standard bottleneck collapse, an explicit auxiliary loss bridges the spatial representations directly to biological data. The pathway scores $S$ are supervised against a pathway ground truth ($Y_{genes} \cdot M^T$) using a Pearson Correlation Coefficient (PCC) loss. + $$L_{total} = L_{gene} + \lambda_{pathway} (1 - PCC(S, Y_{genes} \cdot M^T))$$ +- **Benefit**: The model is forced to explicitly align its internal interaction tokens with concrete biological pathways, granting direct interpretability. -#### 2. Data-Driven Projection (Latent Discovery) +#### 2. 
Data-Driven Discovery (Latent Projection) -In this mode, the model learns its own "latent pathways" based on morphological patterns. +In the absence of a biological prior, the model can learn its own "latent pathways". -- **Implementation**: $\mathbf{W}_{recon}$ is randomly initialized and learned via backpropagation. -- **Sparsity Constraint**: We apply an L1 penalty to force the model to identify "canonical" gene sets: $L_{total} = L_{MSE} + \lambda \|\mathbf{W}_{recon}\|_1$. +- **Implementation**: $\mathbf{W}_{recon}$ is randomly initialized and the auxiliary pathway loss is disabled. +- **Sparsity Constraint**: We apply an L1 penalty to force the model to identify "canonical" sparse gene sets: $L_{total} = L_{gene} + \lambda_{sparsity} \|\mathbf{W}_{recon}\|_1$. - **Benefit**: Can discover novel spatial-transcriptomic relationships that aren't yet captured in curated databases. -- **Architecture Flow**: - 1. **Interaction**: Pathway tokens $P$ query the Histology $H$. - 2. **Activation**: A linear layer reduces $P_{tokens}$ to activation scores $S \in \mathbb{R}^P$. - 3. **Reconstruction**: $\hat{y} = S \cdot \mathbf{W}_{recon} + b$. +## 3. Generalizing to HEST1k Tissues + +The model supports any dataset within the HEST1k collection (e.g., Breast, Kidney, Lung, Colon). Instead of being bound to a single disease context, users can leverage the `--custom-gmt` flag to map genes to pathways relevant to their specific investigation. -## 3. Clinical Application in Bowel Cancer +### Example: Profiling the Tumor Microenvironment -For colorectal cancer, we should prioritize monitoring these specific pathways: +Regardless of the tissue of origin (e.g., Kidney versus Breast), researchers often track core functional states within the tumor microenvironment. 
A user might define a `.gmt` file to explicitly monitor: -| Pathway | Clinically Relevant Genes | Clinical Significance | +| Pathway Concept | Hallmarks / Relevant Genes | Interpretive Value across Tissues | | :--- | :--- | :--- | -| **Wnt Signaling** | `CTNNB1`, `MYC`, `AXIN2` | Common driver in CRC (APC mutations) | -| **MMR / DNA Repair** | `MLH1`, `MSH2`, `MSH6` | MSI vs MSS status (Immunotherapy response) | -| **EMT** | `SNAI1`, `VIM`, `ZEB1` | Tumor invasion and metastasis risk | -| **Angiogenesis** | `VEGFA`, `FLT1` | Potential for anti-angiogenic therapy | +| **Hypoxia & Angiogenesis** | `VEGFA`, `FLT1`, `HIF1A` | Identifies oxygen-deprived or highly vascularized tumor cores. | +| **Immune Infiltration** | `CD8A`, `GZMB`, `IFNG` | Maps regions of active anti-tumor immune response. | +| **Stromal / EMT** | `VIM`, `SNAI1`, `ZEB1` | Highlights desmoplastic stroma and invasion fronts. | +| **Proliferation** | `MKI67`, `PCNA`, `MYC` | Pinpoints highly active, dividing cell populations. | + +By supplying these functional groupings via `--custom-gmt`, the model's MTL process explicitly aligns its spatial interaction tokens to monitor these exact states across any whole-slide image in the HEST1k dataset. ## 4. Implementation Status ### Implemented - **MSigDB Hallmarks Initialization** (`--pathway-init` flag): Downloads the GMT file, matches genes against `global_genes.json`, and initializes `gene_reconstructor.weight` with the binary membership matrix. See [`pathways.py`](../src/spatial_transcript_former/data/pathways.py). - - 50 Hallmark pathways (fixed when using `--pathway-init`) - - ~54% gene coverage (542/1000 genes mapped to at least one pathway) - - GMT file cached in `.cache/` after first download + - 50 Hallmark pathways (default fixed fallback when using `--pathway-init`). + - GMT file cached in `.cache/` after first download. 
+- **Custom Pathway Definitions** (`--custom-gmt` flag): Users can override the default Hallmarks by providing a URL or local path to a `.gmt` file, enabling custom database integrations (e.g., KEGG, Reactome, or highly specific tissue masks). - **Sparsity Regularization** (`--sparsity-lambda` flag): L1 penalty on `gene_reconstructor` weights to encourage pathway-like groupings when using data-driven (random) initialization. @@ -79,8 +85,9 @@ python -m spatial_transcript_former.train \ --model interaction --num-pathways 50 --sparsity-lambda 0.01 ... ``` +- **Spatial Pathway Maps**: Visualize pathway activations as spatial heatmaps overlaid on histology using `stf-predict`. See the [README](../README.md) for inference instructions. + ### Future Work -- **KEGG/Reactome**: More granular pathway databases for finer-grained analysis. -- **Post-Hoc Enrichment**: `gseapy` integration for pathway activation maps from model outputs. -- **Spatial Pathway Maps**: Visualize pathway activations as spatial heatmaps overlaid on histology. +- **Post-Hoc Enrichment**: `gseapy` integration for pathway activation maps from model outputs without architectural bottlenecks. +- **End-to-End Risk Assessment Module**: Developing a downstream prediction system that takes the spatially-resolved pathway activations and gene expressions derived from the model and maps them directly to clinical risk and survival outcomes. diff --git a/docs/TESTING.md b/docs/TESTING.md index 73f118a..255bda1 100644 --- a/docs/TESTING.md +++ b/docs/TESTING.md @@ -29,8 +29,26 @@ Or using the provided PowerShell script: - Patient-level splitting (train/val/test). - Leakage prevention (ensuring patients don't overlap between splits). -## Adding New Tests +## Contributor Guidelines: Adding New Tests When adding new functionality, please add corresponding tests in the `tests/` directory. -- Use `unittest` or `pytest` style tests. 
-- Mock external calls (like `huggingface_hub` or large file I/O) to keep tests fast and offline-capable where possible. + +- **Framework**: Use `unittest` or `pytest` style tests. +- **Mocking**: Mock external calls (like `huggingface_hub` or large file I/O) to keep tests fast and offline-capable where possible. +- **Discussion**: We are always looking for ways to improve our testing practices! If you have ideas for better test architecture, coverage strategies, or tooling, please feel free to open a discussion or issue. + +### Mutation Testing + +While standard unit tests ensure the code behaves as expected under specific conditions, they don't always guarantee the robustness of the tests themselves. + +We strongly encourage **Mutation Testing** when contributing critical components. Mutation testing introduces small changes (mutations) into the source code and checks if your tests catch them (by failing). If the test suite still passes on the mutated code (a "surviving mutant"), the tests are not strict enough. + +Our preferred method for mutation testing in Python is **[cosmic-ray](https://github.com/sixty-north/cosmic-ray)**. + +To get started with `cosmic-ray`: + +1. Install it via pip: `pip install cosmic-ray` +2. Initialize a configuration file for your module. +3. Run the mutation tests and review the survival report to strengthen your test suites. + +If you find ways to automate or better integrate mutation testing into our CI pipeline, we would welcome those discussions! diff --git a/docs/TRAINING_GUIDE.md b/docs/TRAINING_GUIDE.md index 3b0d9c0..5a5e923 100644 --- a/docs/TRAINING_GUIDE.md +++ b/docs/TRAINING_GUIDE.md @@ -14,7 +14,7 @@ conda activate SpatialTranscriptFormer ## 1. Single Patch Regression (Baselines) -Predicts gene expression for a single 224x224 patch. +Predicts gene expression for a single 224x224 patch. There are no cross-attention interactions between patches or pathways. 
### HE2RNA (ResNet50) @@ -126,6 +126,27 @@ python -m spatial_transcript_former.train \ > **Note**: `--pathway-init` overrides `--num-pathways` to 50 (the number of Hallmark gene sets). The GMT file is cached in `.cache/` after first download. +### Data-Driven Discovery (Latent Pathways) + +To allow the model to discover its own spatial-transcriptomic relationships without biological priors, omit `--pathway-init` and apply sparsity regularization (`--sparsity-lambda`). This aims to force the model to identify "canonical" sparse gene sets. + +```bash +python -m spatial_transcript_former.train \ + --data-dir A:\hest_data \ + --model interaction \ + --backbone ctranspath \ + --use-nystrom \ + --num-pathways 50 \ + --sparsity-lambda 0.01 \ + --precomputed \ + --whole-slide \ + --use-amp \ + --log-transform \ + --epochs 100 +``` + +> **Note**: Without `--pathway-init`, the model disables the `AuxiliaryPathwayLoss` and relies entirely on the main reconstruction objectives and the L1 sparsity penalty. This configuration is experimental: no results have been obtained with it yet. + ### Robust Counting: ZINB + Auxiliary Loss For raw count data with high sparsity, using the ZINB distribution and auxiliary pathway supervision is recommended. @@ -222,6 +243,7 @@ python -m spatial_transcript_former.train --resume --output-dir runs/my_experime | `--feature-dir` | Explicit path to precomputed features directory. | Overrides auto-detection. | | `--loss` | Loss function: `mse`, `pcc`, `mse_pcc`, `zinb`. | `mse_pcc` or `zinb` recommended. | | `--pathway-loss-weight` | Weight ($\lambda$) for auxiliary pathway supervision. | Set `0.5` or `1.0` with `interaction` model. | +| `--sparsity-lambda` | L1 regularization weight for discovering latent pathways. | Use `0.01` when `--pathway-init` is NOT used. | | `--interactions` | Enabled attention quadrants: `p2p`, `p2h`, `h2p`, `h2h`. | Default: `all` (Full Interaction). | | `--log-transform` | Apply log1p to gene expression targets. | Recommended for raw count data. 
| | `--num-genes` | Number of HVGs to predict (default: 1000). | Match your `global_genes.json`. | diff --git a/scripts/run_preset.py b/scripts/run_preset.py index 8e51d2b..fe3e423 100644 --- a/scripts/run_preset.py +++ b/scripts/run_preset.py @@ -105,7 +105,7 @@ "--batch-size", "4", "--epochs", - "2000", + "2500", "--use-amp", "--loss", "zinb", @@ -120,27 +120,8 @@ "--plot-pathways", "--resume", ], - # Legacy / Specialized - "stf_pathway_nystrom": STF_COMMON - + [ - "--sparsity-lambda", - "0.05", - "--lr", - "1e-4", - "--batch-size", - "8", - "--epochs", - "2000", - ], } -# Add alias for the one currently running to ensure backward compatibility for монитор -PRESETS["stf_interaction_mse_pcc"] = PRESETS["stf_interaction_l2"] + [ - "--pathway-loss-weight", - "0.5", - "--plot-pathways", -] - def main(): parser = argparse.ArgumentParser( diff --git a/src/spatial_transcript_former/data/pathways.py b/src/spatial_transcript_former/data/pathways.py index d5d60a1..0608052 100644 --- a/src/spatial_transcript_former/data/pathways.py +++ b/src/spatial_transcript_former/data/pathways.py @@ -14,7 +14,6 @@ # MSigDB collections URLs (v2024.1.Hs, gene symbols) MSIGDB_URLS = { "hallmarks": "https://data.broadinstitute.org/gsea-msigdb/msigdb/release/2024.1.Hs/h.all.v2024.1.Hs.symbols.gmt", - "c2_kegg": "https://data.broadinstitute.org/gsea-msigdb/msigdb/release/2024.1.Hs/c2.cp.kegg_legacy.v2024.1.Hs.symbols.gmt", "c2_medicus": "https://data.broadinstitute.org/gsea-msigdb/msigdb/release/2024.1.Hs/c2.cp.kegg_medicus.v2024.1.Hs.symbols.gmt", "c2_cgp": "https://data.broadinstitute.org/gsea-msigdb/msigdb/release/2024.1.Hs/c2.cgp.v2024.1.Hs.symbols.gmt", } @@ -110,7 +109,7 @@ def get_pathway_init( tuple: (membership_matrix [Tensor (P, G)], pathway_names [list of str]) """ if gmt_urls is None: - gmt_urls = [MSIGDB_URLS["hallmarks"], MSIGDB_URLS["c2_kegg"]] + gmt_urls = [MSIGDB_URLS["hallmarks"]] combined_dict = {} diff --git a/src/spatial_transcript_former/data/utils.py 
b/src/spatial_transcript_former/data/utils.py index 4908f3d..77e6a05 100644 --- a/src/spatial_transcript_former/data/utils.py +++ b/src/spatial_transcript_former/data/utils.py @@ -42,18 +42,7 @@ def get_sample_ids( df_human = df_filtered[df_filtered["species"] == "Homo sapiens"] human_ids = df_human["id"].tolist() - # Filter for Human Bowel - df_bowel = df_human[ - df_human["organ"].str.contains("Bowel", case=False, na=False) - ] - - if not df_bowel.empty: - print(f"Filtering for Human Bowel samples ({len(df_bowel)} found)...") - final_ids = df_bowel["id"].tolist() - else: - print("No Human Bowel samples found, falling back to all Human samples...") - final_ids = human_ids - + final_ids = human_ids if not final_ids: print("Warning: No Human samples found. Using all files.") final_ids = all_ids diff --git a/src/spatial_transcript_former/predict.py b/src/spatial_transcript_former/predict.py index f8783cf..36cd0da 100644 --- a/src/spatial_transcript_former/predict.py +++ b/src/spatial_transcript_former/predict.py @@ -120,8 +120,8 @@ def plot_histology_overlay( plt.close("all") -# Fixed bowel-cancer-relevant pathways (MSigDB Hallmark names, without prefix) -BOWEL_CANCER_PATHWAYS = [ +# Fixed representative pathways for visualization (MSigDB Hallmark names) +CORE_PATHWAYS = [ "EPITHELIAL_MESENCHYMAL_TRANSITION", "WNT_BETA_CATENIN_SIGNALING", "INFLAMMATORY_RESPONSE", @@ -176,7 +176,7 @@ def plot_training_summary( if short in name_to_idx: display_pathways.append((short, name_to_idx[short])) else: - for pw in BOWEL_CANCER_PATHWAYS: + for pw in CORE_PATHWAYS: if pw in name_to_idx: display_pathways.append((pw, name_to_idx[pw])) @@ -394,9 +394,6 @@ def main(): new_state_dict[k[len("_orig_mod.") :]] = v else: new_state_dict[k] = v - # Handle legacy checkpoints missing keys? No, usually extra keys are issue or missing - # But if we added components (pathways?) no they are same architecture just different forward. 
- model.load_state_dict(new_state_dict) model.to(device) model.eval() diff --git a/src/spatial_transcript_former/train.py b/src/spatial_transcript_former/train.py index b24c4f5..70c5fb1 100644 --- a/src/spatial_transcript_former/train.py +++ b/src/spatial_transcript_former/train.py @@ -66,11 +66,12 @@ def setup_model(args, device): with open(genes_path) as f: gene_list = json.load(f) - if getattr(args, "pathways", None): - # If specific pathways requested, search all collections + if getattr(args, "custom_gmt", None): + urls = args.custom_gmt + elif getattr(args, "pathways", None): + # If specific pathways requested but no custom GMT, search standard collections urls = [ MSIGDB_URLS["hallmarks"], - MSIGDB_URLS["c2_kegg"], MSIGDB_URLS["c2_medicus"], MSIGDB_URLS["c2_cgp"], ] @@ -196,21 +197,14 @@ def load_checkpoint(model, optimizer, scaler, output_dir, model_name, device): print(f"Resuming from {ckpt_path}...") checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True) - if "model_state_dict" in checkpoint: - model.load_state_dict(checkpoint["model_state_dict"]) - if "optimizer_state_dict" in checkpoint: - optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) - if "scaler_state_dict" in checkpoint and scaler is not None: - scaler.load_state_dict(checkpoint["scaler_state_dict"]) + model.load_state_dict(checkpoint["model_state_dict"]) + if "optimizer_state_dict" in checkpoint: + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + if "scaler_state_dict" in checkpoint and scaler is not None: + scaler.load_state_dict(checkpoint["scaler_state_dict"]) - start_epoch = checkpoint.get("epoch", -1) + 1 - best_val_loss = checkpoint.get("best_val_loss", float("inf")) - else: - # Legacy checkpoint (raw state dict) - model.load_state_dict(checkpoint) - start_epoch = 0 - best_val_loss = float("inf") - print("Loaded weights only (legacy checkpoint).") + start_epoch = checkpoint.get("epoch", -1) + 1 + best_val_loss = checkpoint.get("best_val_loss", 
float("inf")) print(f"Resumed at epoch {start_epoch + 1}") return start_epoch, best_val_loss @@ -346,7 +340,13 @@ def parse_args(): "--pathways", nargs="+", default=None, - help="List of MSigDB pathway names to explicitly instantiate (e.g. KEGG_COLORECTAL_CANCER)", + help="List of MSigDB pathway names to explicitly instantiate (e.g. HALLMARK_APOPTOSIS). If none are provided but --pathway-init is enabled, all pathways in the provided GMTs will be loaded.", + ) + g.add_argument( + "--custom-gmt", + nargs="+", + default=None, + help="List of URLs or local paths to custom .gmt files for pathway initialization. Overrides standard MSigDB defaults if provided.", ) return parser.parse_args() diff --git a/src/spatial_transcript_former/visualization.py b/src/spatial_transcript_former/visualization.py index 1e3b367..2aaef7a 100644 --- a/src/spatial_transcript_former/visualization.py +++ b/src/spatial_transcript_former/visualization.py @@ -70,12 +70,14 @@ def _compute_pathway_truth(gene_truth, gene_names, args=None): filter_names = None urls = None if args is not None and getattr(args, "pathway_init", False): - urls = [ - MSIGDB_URLS["hallmarks"], - MSIGDB_URLS["c2_kegg"], - MSIGDB_URLS["c2_medicus"], - MSIGDB_URLS["c2_cgp"], - ] + if getattr(args, "custom_gmt", None): + urls = args.custom_gmt + else: + urls = [ + MSIGDB_URLS["hallmarks"], + MSIGDB_URLS["c2_medicus"], + MSIGDB_URLS["c2_cgp"], + ] filter_names = getattr(args, "pathways", None) pw_matrix, pw_names = get_pathway_init( @@ -120,7 +122,7 @@ def run_inference_plot(model, args, sample_id, epoch, device): """ Run inference on a single sample and save a unified pathway visualization. 
- Produces a single figure per epoch showing histology + fixed bowel-cancer + Produces a single figure per epoch showing histology + core pathways (ground truth vs prediction), where ground truth is computed by projecting true gene expression through the model's gene_reconstructor via pseudo-inverse so both live in the same activation space. diff --git a/tests/test_pathways.py b/tests/test_pathways.py index 4095971..958498e 100644 --- a/tests/test_pathways.py +++ b/tests/test_pathways.py @@ -113,8 +113,8 @@ def test_pathway_count(self, pathway_result): _, names = pathway_result assert len(names) == 50 - def test_bowel_cancer_pathways_exist(self, pathway_result): - """All 6 disease-relevant pathways should be in the names list.""" + def test_core_pathways_exist(self, pathway_result): + """All 6 representative pathways should be in the names list.""" _, names = pathway_result short_names = [n.replace("HALLMARK_", "") for n in names] required = [ @@ -149,15 +149,15 @@ def test_consistent_across_calls(self, gene_list): assert names1 == names2 def test_output_shape(self, gene_list): - """Pathway truth should be (N, P) where P=236 (Hallmarks + C2 KEGG).""" + """Pathway truth should be (N, P) where P=50 (Hallmarks default).""" from spatial_transcript_former.visualization import _compute_pathway_truth N = 150 gene_truth = np.random.rand(N, len(gene_list)).astype(np.float32) result, names = _compute_pathway_truth(gene_truth, gene_list) - assert result.shape == (N, 236) - assert len(names) == 236 + assert result.shape == (N, 50) + assert len(names) == 50 def test_spatial_variation(self, gene_list): """Pathway truth should have spatial variation (non-zero std).""" diff --git a/tests/test_visualization.py b/tests/test_visualization.py index e050a83..c149232 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -12,7 +12,7 @@ matplotlib.use("Agg") from spatial_transcript_former.predict import ( - BOWEL_CANCER_PATHWAYS, + CORE_PATHWAYS, 
plot_training_summary, ) @@ -58,16 +58,16 @@ def mock_data(pathway_names): # --------------------------------------------------------------------------- -class TestBowelCancerPathways: +class TestRepresentativePathways: def test_all_pathways_exist_in_msigdb(self, pathway_names): - """All 6 bowel cancer pathways should be in the MSigDB Hallmarks.""" + """All 6 representative pathways should be in the MSigDB Hallmarks.""" short_names = [n.replace("HALLMARK_", "") for n in pathway_names] - for pw in BOWEL_CANCER_PATHWAYS: + for pw in CORE_PATHWAYS: assert pw in short_names, f"Missing: {pw}" def test_pathway_count(self): """Should have exactly 6 fixed pathways.""" - assert len(BOWEL_CANCER_PATHWAYS) == 6 + assert len(CORE_PATHWAYS) == 6 # ---------------------------------------------------------------------------