|
908 | 908 | "tags": [] |
909 | 909 | }, |
910 | 910 | "outputs": [], |
911 | | - "source": [ |
912 | | - "model = regularizedvi.RegularizedMultimodalVI(\n", |
913 | | - " mdata,\n", |
914 | | - " # Per-modality architecture sizes\n", |
915 | | - " n_hidden={\"rna\": 512, \"atac\": 256},\n", |
916 | | - " n_latent={\"rna\": 128, \"atac\": 64},\n", |
917 | | - " n_layers=1,\n", |
918 | | - " # Z combination: concatenation preserves modality-specific signal\n", |
919 | | - " latent_mode=\"concatenation\",\n", |
920 | | - " # Per-modality flags\n", |
921 | | - " additive_background_modalities=[\"rna\"], # ambient correction for RNA only\n", |
922 | | - " feature_scaling_modalities=[\"rna\", \"atac\"] if use_feature_scaling else [], # per-covariate scaling\n", |
923 | | - " # Dispersion\n", |
924 | | - " dispersion=\"gene-batch\",\n", |
925 | | - " regularise_dispersion=True,\n", |
926 | | - " regularise_dispersion_prior=regularise_dispersion_prior,\n", |
927 | | - " dispersion_hyper_prior_alpha=dispersion_hyper_prior_alpha,\n", |
928 | | - " dispersion_hyper_prior_beta=dispersion_hyper_prior_beta,\n", |
929 | | - " # Background prior\n", |
930 | | - " additive_bg_prior_alpha=additive_bg_prior_alpha,\n", |
931 | | - " additive_bg_prior_beta=additive_bg_prior_beta,\n", |
932 | | - " regularise_background=regularise_background,\n", |
933 | | - " # Library centering\n", |
934 | | - " library_log_means_centering_sensitivity=library_log_means_centering_sensitivity,\n", |
935 | | - " library_log_vars_weight=library_log_vars_weight,\n", |
936 | | - " # Learnable modality scaling\n", |
937 | | - " learnable_modality_scaling=learnable_modality_scaling,\n", |
938 | | - " modality_scale_prior_concentration=modality_scale_prior_concentration,\n", |
939 | | - " # Pearson correlation metrics\n", |
940 | | - " compute_pearson=compute_pearson,\n", |
941 | | - " # Decoder regularisation and initialisation\n", |
942 | | - " decoder_weight_l2=decoder_weight_l2,\n", |
943 | | - " decoder_cov_weight_l2=decoder_cov_weight_l2,\n", |
944 | | - " init_decoder_bias=init_decoder_bias,\n", |
945 | | - " bg_init_gene_fraction=bg_init_gene_fraction,\n", |
946 | | - " decoder_bias_multiplier=_decoder_bias_multiplier,\n", |
947 | | - " # Residual library encoder\n", |
948 | | - " residual_library_encoder=residual_library_encoder,\n", |
949 | | - " library_obs_w_prior_rate=library_obs_w_prior_rate,\n", |
950 | | - ")\n", |
951 | | - "\n", |
952 | | - "print(model)\n", |
953 | | - "print(f\"\\nTotal latent dim: {model.module.total_latent_dim}\")\n", |
954 | | - "print(f\" RNA Z dim: {model.module.n_latent_dict['rna']}\")\n", |
955 | | - "print(f\" ATAC Z dim: {model.module.n_latent_dict['atac']}\")\n", |
956 | | - "print(f\"\\nRegion factors modalities: {model.module.feature_scaling_modalities}\")\n", |
957 | | - "for name in model.module.feature_scaling:\n", |
958 | | - " print(f\" {name} feature scaling shape: {model.module.feature_scaling[name].shape}\")\n", |
959 | | - "print(f\"\\nRegularise background: {regularise_background}\")\n", |
960 | | - "if regularise_background:\n", |
961 | | - " print(f\" Prior: Gamma({additive_bg_prior_alpha}, {additive_bg_prior_beta})\")\n", |
962 | | - " print(f\" Prior mean: {additive_bg_prior_alpha / additive_bg_prior_beta:.4f}\")\n", |
963 | | - "print(f\"Library centering sensitivity: {library_log_means_centering_sensitivity}\")\n", |
964 | | - "print(f\"Library log vars weight: {library_log_vars_weight}\")\n", |
965 | | - "print(f\"Decoder weight L2: {decoder_weight_l2}\")\n", |
966 | | - "print(f\"Decoder cov weight L2: {decoder_cov_weight_l2}\")\n", |
967 | | - "print(f\"Init decoder bias: {init_decoder_bias}\")\n", |
968 | | - "print(f\"BG init gene fraction: {bg_init_gene_fraction}\")\n", |
969 | | - "print(f\"Decoder bias multiplier: {_decoder_bias_multiplier}\")\n", |
970 | | - "print(f\"Dispersion prior: Gamma({dispersion_hyper_prior_alpha:.1f}, {dispersion_hyper_prior_beta:.2f})\")\n", |
971 | | - "print(f\"Residual library encoder: {residual_library_encoder}\")\n", |
972 | | - "print(f\"Library obs w prior rate: {library_obs_w_prior_rate}\")\n", |
973 | | - "print(f\"Learnable modality scaling: {learnable_modality_scaling}\")\n", |
974 | | - "if learnable_modality_scaling:\n", |
975 | | - " for name in model.module.modality_scale_init:\n", |
976 | | - " print(f\" {name} init: {model.module.modality_scale_init[name]:.3f}\")\n", |
977 | | - "print(f\"Compute Pearson: {compute_pearson}\")\n", |
978 | | - "\n", |
979 | | - "# Library encoder bias inspection\n", |
980 | | - "if hasattr(model.module, \"library_obs_w_mu\"):\n", |
981 | | - " import math\n", |
982 | | - " from torch.distributions import LogNormal\n", |
983 | | - " import torch\n", |
984 | | - "\n", |
985 | | - " for name in model.module.modality_names:\n", |
986 | | - " w_mu = model.module.library_obs_w_mu[name]\n", |
987 | | - " w_sigma = torch.exp(model.module.library_obs_w_log_sigma[name])\n", |
988 | | - " print(f\" library_obs_w_{name}: E[w]={LogNormal(w_mu, w_sigma).mean.item():.3f}\")\n", |
989 | | - "for name in model.module.modality_names:\n", |
990 | | - " bias_val = model.module.l_encoders[name].mean_encoder.bias.item()\n", |
991 | | - " print(f\" l_encoder[{name}].bias = {bias_val:.4f} (exp = {math.exp(bias_val):.4f})\")" |
992 | | - ] |
| 911 | + "source": "import math\n\nmodel = regularizedvi.RegularizedMultimodalVI(\n mdata,\n # Per-modality architecture sizes\n n_hidden={\"rna\": 512, \"atac\": 256},\n n_latent={\"rna\": 128, \"atac\": 64},\n n_layers=1,\n # Z combination: concatenation preserves modality-specific signal\n latent_mode=\"concatenation\",\n # Per-modality flags\n additive_background_modalities=[\"rna\"], # ambient correction for RNA only\n feature_scaling_modalities=[\"rna\", \"atac\"] if use_feature_scaling else [], # per-covariate scaling\n # Dispersion\n dispersion=\"gene-batch\",\n regularise_dispersion=True,\n regularise_dispersion_prior=regularise_dispersion_prior,\n dispersion_hyper_prior_alpha=dispersion_hyper_prior_alpha,\n dispersion_hyper_prior_beta=dispersion_hyper_prior_beta,\n # Background prior\n additive_bg_prior_alpha=additive_bg_prior_alpha,\n additive_bg_prior_beta=additive_bg_prior_beta,\n regularise_background=regularise_background,\n # Library centering\n library_log_means_centering_sensitivity=library_log_means_centering_sensitivity,\n library_log_vars_weight=library_log_vars_weight,\n # Learnable modality scaling\n learnable_modality_scaling=learnable_modality_scaling,\n modality_scale_prior_concentration=modality_scale_prior_concentration,\n # Pearson correlation metrics\n compute_pearson=compute_pearson,\n # Decoder regularisation and initialisation\n decoder_weight_l2=decoder_weight_l2,\n decoder_cov_weight_l2=decoder_cov_weight_l2,\n init_decoder_bias=init_decoder_bias,\n bg_init_gene_fraction=bg_init_gene_fraction,\n decoder_bias_multiplier=_decoder_bias_multiplier,\n # Residual library encoder\n residual_library_encoder=residual_library_encoder,\n library_obs_w_prior_rate=library_obs_w_prior_rate,\n)\n\nprint(model)\nprint(f\"\\nTotal latent dim: {model.module.total_latent_dim}\")\nprint(f\" RNA Z dim: {model.module.n_latent_dict['rna']}\")\nprint(f\" ATAC Z dim: {model.module.n_latent_dict['atac']}\")\nprint(f\"\\nRegion factors modalities: {model.module.feature_scaling_modalities}\")\nfor name in model.module.feature_scaling:\n print(f\" {name} feature scaling shape: {model.module.feature_scaling[name].shape}\")\nprint(f\"\\nRegularise background: {regularise_background}\")\nif regularise_background:\n print(f\" Prior: Gamma({additive_bg_prior_alpha}, {additive_bg_prior_beta})\")\n print(f\" Prior mean: {additive_bg_prior_alpha / additive_bg_prior_beta:.4f}\")\nprint(f\"Library centering sensitivity: {library_log_means_centering_sensitivity}\")\nprint(f\"Library log vars weight: {library_log_vars_weight}\")\nprint(f\"Decoder weight L2: {decoder_weight_l2}\")\nprint(f\"Decoder cov weight L2: {decoder_cov_weight_l2}\")\nprint(f\"Init decoder bias: {init_decoder_bias}\")\nprint(f\"BG init gene fraction: {bg_init_gene_fraction}\")\nprint(f\"Decoder bias multiplier: {_decoder_bias_multiplier}\")\nprint(f\"Dispersion prior: Gamma({dispersion_hyper_prior_alpha:.1f}, {dispersion_hyper_prior_beta:.2f})\")\nprint(f\"Residual library encoder: {residual_library_encoder}\")\nprint(f\"Library obs w prior rate: {library_obs_w_prior_rate}\")\nprint(f\"Learnable modality scaling: {learnable_modality_scaling}\")\nif learnable_modality_scaling:\n for name in model.module.modality_scale_init:\n print(f\" {name} init: {model.module.modality_scale_init[name]:.3f}\")\nprint(f\"Compute Pearson: {compute_pearson}\")\n\n# Library encoder bias inspection\nif hasattr(model.module, \"library_obs_w_mu\"):\n from torch.distributions import LogNormal\n import torch\n\n for name in model.module.modality_names:\n w_mu = model.module.library_obs_w_mu[name]\n w_sigma = torch.exp(model.module.library_obs_w_log_sigma[name])\n print(f\" library_obs_w_{name}: E[w]={LogNormal(w_mu, w_sigma).mean.item():.3f}\")\nfor name in model.module.modality_names:\n bias_val = model.module.l_encoders[name].mean_encoder.bias.item()\n print(f\" l_encoder[{name}].bias = {bias_val:.4f} (exp = {math.exp(bias_val):.4f})\")" |
993 | 912 | }, |
994 | 913 | { |
995 | 914 | "cell_type": "markdown", |
|
0 commit comments