Bayesian extension of scVI for single-cell/nucleus RNA-seq and multiome (RNA+ATAC) data integration. Built on cell2location/cell2fate modelling principles (Kleshchevnikov et al. 2022, Aivazidis et al. 2025). Adds structural inductive biases — ambient RNA correction, hierarchical dispersion prior, batch-free decoder, learned library size — that make high-capacity models (n_hidden=512+, n_latent=128+) well-behaved without substantial per-dataset hyperparameter tuning. Designed for complex datasets with hundreds of cell types (whole-embryo atlases, cross-atlas integration).
| Model class | Module | File | Purpose |
|---|---|---|---|
| AmbientRegularizedSCVI | RegularizedVAE | _module.py (1372L) | Single-modal RNA |
| RegularizedMultimodalVI | RegularizedMultimodalVAE | _multimodule.py (1843L) | Multi-modal RNA+ATAC |
Supporting: _model.py (942L), _multimodel.py (1571L), _components.py (472L), _constants.py (78L).
- `ambient_covariate_keys` — additive background per batch
- `nn_conditioning_covariate_keys` — decoder categorical injection
- `feature_scaling_covariate_keys` — multiplicative scaling
- `dispersion_key` — overdispersion groups
- `library_size_key` — library size prior groups
- `encoder_covariate_keys` — encoder injection (default False, matching scVI/MultiVI/PeakVI)
- `batch_key` alone fans out to `ambient_covariate_keys`, `library_size_key`, `dispersion_key` (backward compatibility)
- `library_key` — finest technical unit (sequencing run, lane, GEM well). In `_model.py`, `batch_key` is semantically equivalent to `library_key`: it is the finest technical unit, fanning out to `ambient_covariate_keys`, `library_size_key`, `dispersion_key`.
- `dataset_key` — groups of libraries from the same study. Optional mid-level grouping; each `library_key` value must map to exactly one `dataset_key` value (validated at setup).
- `technical_covariate_keys` — broad technical axes (embryo, experiment type, 10x kit). Optional, non-hierarchical; multiple values may be specified.
- New code should use purpose-based keys instead of the generic `batch_key`. The model's existing `ambient_covariate_keys`, `dispersion_key`, `library_size_key` are all finest-unit groupings, aligning with `library_key` terminology (see the setup sketch after this list).
- Graceful degradation: when only `library_key` is provided, multi-level comparisons (cross-dataset, cross-technical) become empty but do not raise errors.
- Backward compatibility: existing code (`_model.py` `batch_key`, `_integration_metrics.py` `batch_key`) retains current semantics. The new neighbourhood correlation metrics module `src/regularizedvi/plt/_neighbourhood_correlation.py` uses `library_key`, `dataset_key`, `technical_covariate_keys` exclusively.
- Curated marker genes: `docs/notebooks/known_marker_genes.csv` (~192 genes; columns: gene, cell_type, lineage, category).
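A minimal usage sketch of the purpose-based keys, assuming a scvi-tools-style `setup_anndata` classmethod (the exact signature lives in `_model.py`/`_multimodel.py` and may differ; dataset path and column names below are hypothetical):

```python
import anndata as ad
from regularizedvi import AmbientRegularizedSCVI  # assumed import path

adata = ad.read_h5ad("embryo_atlas.h5ad")  # hypothetical dataset

# Purpose-based keys instead of a generic batch_key:
# library_key is the finest technical unit; dataset_key groups libraries
# (each library must map to exactly one dataset); technical_covariate_keys
# are broad, non-hierarchical axes.
AmbientRegularizedSCVI.setup_anndata(
    adata,
    library_key="sequencing_run",              # run / lane / GEM well
    dataset_key="study",                       # optional mid-level grouping
    technical_covariate_keys=["embryo", "kit_10x"],
)
model = AmbientRegularizedSCVI(adata, n_hidden=512, n_latent=128)
model.train()
```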
- Ambient RNA: additive background s_e,g = exp(beta) with Gamma(1, 100) prior — per gene/batch (not regularised with the prior by default)
- Feature scaling: y_t,g = softplus(gamma)/0.7 with Gamma(200, 200) prior — multiplicative bias per covariate group (regularised with the prior by default)
- Hierarchical dispersion: variational LogNormal posterior with containment prior 1/sqrt(theta) ~ Exp(lambda), lambda ~ Gamma(9, 3)
- Learned library size: learned per cell, with 0.5 variance scaling on the prior
- Observation model: GammaPoisson (= NegativeBinomial) with softplus activation (not softmax), since rho + s need not sum to 1 (a sketch of the ambient and feature-scaling prior penalties follows this list)
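A sketch of how the ambient and feature-scaling priors could be scored as penalties with `torch.distributions`; parameter names and shapes are illustrative, not the module's actual attributes:

```python
import torch
import torch.nn.functional as F
from torch.distributions import Gamma

n_genes, n_batch, n_groups = 2000, 4, 3  # toy sizes

# Illustrative free parameters; the real module stores these differently.
beta = torch.zeros(n_batch, n_genes, requires_grad=True)    # ambient, log-scale
gamma = torch.zeros(n_groups, n_genes, requires_grad=True)  # feature scaling, raw

ambient = torch.exp(beta)            # s_{e,g}: additive background per gene/batch
scaling = F.softplus(gamma) / 0.7    # y_{t,g}: ~1 at gamma=0 (softplus(0)/0.7 ≈ 0.99)

# Gamma(1, 100) keeps ambient counts small; Gamma(200, 200) has mean 1 and
# sd ~0.07, keeping feature scaling near 1.
ambient_penalty = -Gamma(1.0, 100.0).log_prob(ambient).sum()
scaling_penalty = -Gamma(200.0, 200.0).log_prob(scaling).sum()
```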
Per-modality encoders, concatenated latent space [z_atac; z_rna] (alphabetical sort of modality_names), all decoders see full z. Decoders are symmetrical with options to switch off purpose-based covariates (unlike MultiVI).
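A one-step sketch of the latent concatenation (illustrative names and shapes):

```python
import torch

# Per-modality posterior samples, keyed by modality name (illustrative).
z_per_modality = {"rna": torch.randn(256, 64), "atac": torch.randn(256, 64)}

# Concatenate in alphabetical order of modality_names: [z_atac; z_rna].
z = torch.cat([z_per_modality[m] for m in sorted(z_per_modality)], dim=-1)
assert z.shape == (256, 128)  # every decoder receives the full joint latent
```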
- Computes Jacobian of decoder output w.r.t. each modality's latent dimensions using finite differences
- Per-cell attribution scores: ||J_rna||_F vs ||J_atac||_F (Frobenius norm of per-modality Jacobian blocks)
- `plot_attribution_scatter()` — convenience method for UMAP-colored attribution visualization (a finite-difference sketch follows this list)
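A minimal sketch of the finite-difference attribution idea; the actual `get_modality_attribution()` may batch, normalize, and scale differently, and all names below are illustrative:

```python
import torch

def modality_attribution(decoder, z, slices, eps=1e-2):
    """Per-cell Frobenius norm of the decoder Jacobian w.r.t. each modality's
    latent block, via central finite differences.

    z: (n_cells, n_latent); slices: e.g. {"rna": slice(64, 128), "atac": slice(0, 64)}.
    """
    scores = {}
    for name, sl in slices.items():
        sq_norm = torch.zeros(z.shape[0])
        for j in range(sl.start, sl.stop):
            z_plus, z_minus = z.clone(), z.clone()
            z_plus[:, j] += eps
            z_minus[:, j] -= eps
            # One Jacobian column per perturbed latent dimension.
            col = (decoder(z_plus) - decoder(z_minus)) / (2 * eps)
            sq_norm += (col ** 2).sum(dim=-1)
        scores[name] = sq_norm.sqrt()  # ||J_modality||_F per cell
    return scores
```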
- `get_latent_representation()`: returns posterior mean z (or sampled z) per cell
- `get_normalized_expression()`: returns denoised expression (decoder output, library-size normalized)
- `get_modality_attribution()`: modality attribution via Jacobian analysis
- Both support `batch_size` for memory-efficient inference on large datasets
- Input x (log-library-normalized) → z_encoder → qz (mean, var) → z sample
- x → l_encoder → ql (library mean, var) → library sample
- Returns dict: `z`, `qz`, `ql`, `library`
- Continuous covariates always concatenated to encoder input
- Categorical covariates passed to encoder only if `encoder_covariate_keys` explicitly set (default False); a condensed sketch of this path follows
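A compressed sketch of the inference path, assuming scvi-tools-style encoders that return `(mean, var, sample)`; module internals may differ:

```python
import torch

def inference_sketch(x, z_encoder, l_encoder):
    # Encoders consume log-library-normalized counts (illustrative transform).
    x_ = torch.log1p(x)
    qz_mean, qz_var, z = z_encoder(x_)        # q(z|x): Normal(mean, var), rsample
    ql_mean, ql_var, library = l_encoder(x_)  # q(l|x): per-cell library posterior
    return {"z": z, "qz": (qz_mean, qz_var),
            "ql": (ql_mean, ql_var), "library": library}
```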
- z → decoder → px_rate (rho, unnormalized gene expression)
- Dispersion: sample theta from LogNormal(px_r_mu[group], px_r_log_sigma[group]) per cell
- Feature scaling: y_t,g = softplus(gamma[group])/0.7, applied multiplicatively
- Ambient RNA: s_e,g = exp(beta[batch]), added to rate
- Final rate: `px_rate = (rho + s) * y * library` (softplus, not softmax — no sum-to-1 constraint)
- Likelihood: GammaPoisson(theta, px_rate) — equivalent to NegativeBinomial (condensed sketch below)
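The generative step, condensed into a sketch; tensor names are illustrative, and the real decoder also consumes the conditioning covariates:

```python
import torch
import torch.nn.functional as F
from torch.distributions import LogNormal

def generative_sketch(z, library, decoder, px_r_mu, px_r_log_sigma, group,
                      beta, gamma, batch):
    rho = decoder(z)                       # decoder ends in softplus → rho >= 0
    # Per-cell dispersion sampled from the group-level variational LogNormal.
    theta = LogNormal(px_r_mu[group], px_r_log_sigma[group].exp()).rsample()
    y = F.softplus(gamma[group]) / 0.7     # multiplicative feature scaling ~1
    s = torch.exp(beta[batch])             # additive ambient background
    px_rate = (rho + s) * y * library      # no sum-to-1 constraint on rho + s
    return px_rate, theta                  # NB(mean=px_rate, dispersion=theta)
```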
- Reconstruction: -log p(x|z) via GammaPoisson log-prob
- KL(q(z)||p(z)): standard VAE latent KL
- KL(q(l)||p(l)): library size KL with 0.5 variance scaling
- Hierarchical dispersion penalty: KL between variational LogNormal and containment prior
- Background penalty: Gamma(1,100) prior on exp(beta) — keeps ambient small
- Feature scaling penalty: Gamma(200,200) prior on softplus(gamma)/0.7 — keeps scaling near 1
- All penalties logged via `extra_metrics` → `compute_and_log_metrics()` → `model.history_`
- `loss()` takes an explicit `n_obs` argument = full training-set size, injected automatically by the `TrainingPlan.n_obs_training` setter via signature introspection (`signature(module.loss).parameters`). Validation also uses `n_obs_training` (not `n_val`) so train/val losses are on the same scale — scvi-tools convention (`_trainingplans.py:356-358`).
- Local (cell-plate) terms — reconstruction loss, `KL(qz‖pz)`, `KL(ql‖pl)`, z-sparsity, horseshoe KL, hidden-activation sparsity — are summed over non-batch dims and meaned over the batch axis inside the main `torch.mean(...)`.
- Global (gene-plate / batch-plate / covariate-plate / plate-less) priors — dispersion variational KL + λ hyperprior, ambient RNA β, feature scaling γ, decoder L1/L2, ARD on z, modality scaling, residual library `w` KL — are added to the loss as `penalty / n_obs`, where `n_obs` is the `loss()` argument (= N_train), NEVER `recon_loss.shape[0]` or `x.shape[0]` (minibatch size). See the sketch after this list.
- `loss()` asserts `n_obs >= batch_size` (overridable via `skip_n_obs_check=True`) to catch missing injection at train time. Unit tests that call `module.loss(...)` directly must pass `skip_n_obs_check=True`.
- Historical bug (fixed 2026-04-11): prior to this fix, all global priors used `n_obs = recon_loss.shape[0]` and `kl_w_total` was added raw, over-weighting every prior by ~B² per epoch (B = n_minibatches). Sweep results from before this fix cannot be directly compared to post-fix results.
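A sketch of the scaling convention (illustrative; the actual `loss()` carries many more terms):

```python
import torch

def loss_sketch(recon_loss, kl_z, kl_l, global_penalty, n_obs,
                skip_n_obs_check=False):
    """Illustrative loss assembly.

    recon_loss / kl_z / kl_l: (batch,) per-cell local terms, already summed
    over non-batch dims; global_penalty: scalar sum of gene/batch/covariate-
    plate prior terms; n_obs: FULL training-set size from the TrainingPlan.
    """
    if not skip_n_obs_check:
        assert n_obs >= recon_loss.shape[0], "n_obs missing or minibatch-sized"
    local = torch.mean(recon_loss + kl_z + kl_l)  # meaned over the batch axis
    # Global priors must be counted once per epoch in expectation:
    # divide by N_train, never by the minibatch size.
    return local + global_penalty / n_obs
```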
- RegularizedFCLayers: dropout applied to INPUT (not output), LayerNorm default (not BatchNorm), configurable activation
- RegularizedEncoder: FCLayers → (mean_encoder, var_encoder) linear heads → Normal distribution
- RegularizedDecoderSCVI: FCLayers → px_scale_decoder linear head → softplus activation
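A condensed sketch of the `RegularizedFCLayers` block ordering described above (illustrative; the real class also supports covariate injection and further options):

```python
import torch.nn as nn

def regularized_fc_block(n_in, n_out, dropout_rate=0.1, activation=nn.ReLU):
    # Dropout applied to the INPUT of each block (not the output),
    # LayerNorm instead of BatchNorm, then a configurable activation.
    return nn.Sequential(
        nn.Dropout(p=dropout_rate),
        nn.Linear(n_in, n_out),
        nn.LayerNorm(n_out),
        activation(),
    )
```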
- Tests: `bash scripts/helper_scripts/run_tests.sh tests/test_model.py -x -q` (114 tests)
- Pre-commit: pyproject-fmt, ruff check, ruff format (auto-fix)
- Python >=3.11
- `_model.py` passes kwargs BOTH via the `_module_kwargs` dict AND the explicit constructor — add new params to BOTH
- `_multimodel.py` uses `**kwargs` from `_module_kwargs` — adding to the dict is sufficient
- `extra_metrics` in `loss()` -> `compute_and_log_metrics()` -> `model.history_` as `{key}_{train|validation}`
- Papermill cannot parse parameter lines with inline comments — use bare assignments only
- `batch_representation="one-hot"` required (embedding incompatible with per-batch ambient RNA)
- `use_feature_scaling=True` (default) creates a (1, n_genes) fallback param even without covariates
- `loss()` requires the `n_obs` kwarg (= N_train from TrainingPlan). Direct test calls must pass `skip_n_obs_check=True`; NEVER use `recon_loss.shape[0]` / `x.shape[0]` to normalize global priors.
GPU experiment specs in `_gpu_jobs.yaml` (20+ experiments on the NeurIPS 2021 adult bone marrow multiome). Testing: library centering, library prior variance, early stopping sensitivity, stratified validation, learnable modality scaling, ATAC filtering thresholds, per-modality learning rates.
- Neighbourhood correlation metrics: `.claude/plans/neighbourhood_correlation_plan.md` — per-cell marker gene correlation with KNN neighbours, label-free. Stratified by library/dataset/technical. A sketch of the metric idea follows.
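A minimal sketch of the label-free metric; the planned module may differ in weighting and stratification details, and all names below are illustrative:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def neighbourhood_marker_correlation(z, marker_expr, n_neighbors=30):
    """Per-cell score: Pearson correlation between a cell's marker-gene
    expression and the mean profile of its KNN neighbours in latent space.

    z: (n_cells, n_latent) latent embedding
    marker_expr: (n_cells, n_markers) expression of curated marker genes
    """
    nn = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(z)
    _, idx = nn.kneighbors(z)
    neigh_mean = marker_expr[idx[:, 1:]].mean(axis=1)  # exclude self
    # Row-wise Pearson r between own profile and neighbourhood mean profile.
    a = marker_expr - marker_expr.mean(axis=1, keepdims=True)
    b = neigh_mean - neigh_mean.mean(axis=1, keepdims=True)
    denom = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) + 1e-12
    return (a * b).sum(axis=1) / denom
```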
docs/notebooks/immune_integration/ — 7-dataset multi-site study (706k cells after QC): bone marrow, TEA-seq PBMC, NEAT-seq CD4, Crohn's PBMC, COVID infant PBMC, lung/spleen, infant/adult spleen. 7 notebooks: data loading -> scrublet -> ATAC loading -> RNA training -> annotation -> CRE selection -> multimodal training.
When a plan reaches its final step (after all implementation/commits but before any irreversible action like GPU job submission or git push), invoke the /verify-implementation skill. Do NOT launch an ad-hoc Agent subagent for this — use the skill specifically. It runs 7 parallel specialist audits (plan completion, post-plan user input, math-code matching, code structure, cross-file consistency, project-specific checks, notebook pre-submission) and produces a structured report. Only proceed past verification once the skill returns PASS (or after each finding has been acknowledged).
Subagents (Plan, Explore, etc.) that lack Write/Edit tools must NEVER use Bash heredocs (`cat > file << EOF`, `echo >`) to create files. Instead, they must return the file content in their response text, and the parent agent must use Write/Edit to create the file. When launching subagents that need to produce files (plans, scripts, configs), use `subagent_type: "general-purpose"`, which has Write/Edit access, or handle file creation in the parent after the subagent returns.
```bash
bsub -q gpu-normal -n 8 -M 40000 -R"select[mem>40000] rusage[mem=40000] span[hosts=1]" -gpu "mode=shared:j_exclusive=yes:gmem=80000:num=1" -e ./%J.gpu.err -o ./%J.gpu.out -J <job_name> 'PYTHONNOUSERSITE=TRUE module load ISG/conda && conda activate regularizedvi && papermill <input.ipynb> <output.ipynb>'
```
| Dataset | Cells | Features | Modalities | MAX MEM | Request | Queue |
|---|---|---|---|---|---|---|
| Immune RNA | 416k | 20k genes | 1 | 30 GB | 60 GB | gpu-normal |
| Bone marrow | 35k | 13k + 116k | 2 (RNA+ATAC) | ~25 GB | 40 GB | gpu-normal |
| Embryo | 424k | 28k + 20k + 23k + 342k | 4 (RNA+spliced+unspliced+ATAC) | 187 GB | 300 GB | gpu-huge, -sp 99 |