diff --git a/decipher/tools/_basis_decomposition/run.py b/decipher/tools/_basis_decomposition/run.py index ee887a1..cc1b291 100644 --- a/decipher/tools/_basis_decomposition/run.py +++ b/decipher/tools/_basis_decomposition/run.py @@ -4,7 +4,7 @@ import pyro.infer import pyro.optim import torch -from pyro.infer import Predictive, SVI, Trace_ELBO +from pyro.infer import SVI, Predictive, Trace_ELBO from tqdm import tqdm from decipher.tools._basis_decomposition.inference import get_inference_guide @@ -65,13 +65,13 @@ def compute_basis_decomposition( break if plot_every_k_epochs > 0 and epoch % plot_every_k_epochs == 0: - from IPython.core import display + from IPython.display import clear_output, display - basis = model._last_basis.detach().numpy() + basis = model._last_basis.detach().cpu().numpy() plt.figure(figsize=(5, 2.5)) _plot_basis(basis) - display.clear_output(wait=True) - display.display(plt.gcf()) + clear_output(wait=True) + display(plt.gcf()) plt.close() model.return_basis = False diff --git a/decipher/tools/_decipher/decipher.py b/decipher/tools/_decipher/decipher.py index a4b8683..c37ad0b 100644 --- a/decipher/tools/_decipher/decipher.py +++ b/decipher/tools/_decipher/decipher.py @@ -112,7 +112,9 @@ def model(self, x, context=None): if self.config.prior == "normal": prior = dist.Normal(0, x.new_ones(self.config.dim_v)).to_event(1) elif self.config.prior == "gamma": - prior = dist.Gamma(0.3, x.new_ones(self.config.dim_v) * 0.8).to_event(1) + prior = dist.Gamma( + 0.3, x.new_ones(self.config.dim_v) * 0.8 + ).to_event(1) else: raise ValueError("Invalid prior, must be normal or gamma") v = pyro.sample("v", prior) @@ -131,7 +133,9 @@ def model(self, x, context=None): self.theta + self._epsilon ) # noinspection PyUnresolvedReferences - x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit) + x_dist = dist.NegativeBinomial( + total_count=self.theta + self._epsilon, logits=logit + ) pyro.sample("x", x_dist.to_event(1), obs=x) def guide(self, x, context=None): @@ -150,7 +154,10 @@ def guide(self, x, context=None): with poutine.scale(scale=self.config.beta): if self.config.prior == "gamma": posterior_v = dist.Gamma(softplus(v_loc), v_scale).to_event(1) - elif self.config.prior == "normal" or self.config.prior == "student-normal": + elif ( + self.config.prior == "normal" + or self.config.prior == "student-normal" + ): posterior_v = dist.Normal(v_loc, v_scale).to_event(1) else: raise ValueError("Invalid prior, must be normal or gamma") @@ -173,17 +180,17 @@ def compute_v_z_numpy(self, x: np.array): Decipher latent z of shape (n_cells, dim_z). """ if type(x) == np.ndarray: - x = torch.tensor(x, dtype=torch.float32) + x = torch.tensor(x, dtype=torch.float32, device=self.device) x = torch.log1p(x) z_loc, _ = self.encoder_x_to_z(x) zx = torch.cat([z_loc, x], dim=-1) v_loc, _ = self.encoder_zx_to_v(zx) - return v_loc.detach().numpy(), z_loc.detach().numpy() + return v_loc.detach().cpu().numpy(), z_loc.detach().cpu().numpy() def impute_gene_expression_numpy(self, x): if type(x) == np.ndarray: - x = torch.tensor(x, dtype=torch.float32) + x = torch.tensor(x, dtype=torch.float32, device=self.device) z_loc, _, _, _ = self.guide(x) mu = self.decoder_z_to_x(z_loc) mu = softmax(mu, dim=-1) diff --git a/decipher/tools/decipher.py b/decipher/tools/decipher.py index a9125e6..a979277 100644 --- a/decipher/tools/decipher.py +++ b/decipher/tools/decipher.py @@ -3,11 +3,11 @@ import numpy as np import pyro +import pyro.optim import scanpy as sc import scipy import torch from matplotlib import pyplot as plt -import pyro.optim from pyro import poutine from pyro.infer import SVI, Trace_ELBO from tqdm import tqdm @@ -17,8 +17,8 @@ from decipher.tools._decipher.data import ( decipher_load_model, decipher_save_model, - make_data_loader_from_adata, get_dense_X, + make_data_loader_from_adata, ) from decipher.tools.utils import EarlyStopping from decipher.utils import DECIPHER_GLOBALS, GIFMaker, is_notebook, load_and_show_gif @@ -58,7 +58,9 @@ def predictive_log_likelihood(decipher, dataloader, n_samples=5): model_trace = poutine.trace( poutine.replay(decipher.model, trace=guide_trace) ).get_trace(*xc) - total_log_prob += model_trace.log_prob_sum() - guide_trace.log_prob_sum() + total_log_prob += ( + model_trace.log_prob_sum() - guide_trace.log_prob_sum() + ) log_weights.append(total_log_prob) finally: @@ -206,7 +208,8 @@ def decipher_train( decipher.eval() val_nll = ( - -predictive_log_likelihood(decipher, dataloader_val, n_samples=5) / adata_val.shape[0] + -predictive_log_likelihood(decipher, dataloader_val, n_samples=5) + / adata_val.shape[0] ) val_losses.append(val_nll) pbar.set_description( @@ -223,7 +226,7 @@ def decipher_train( plot_decipher_v(adata, basis="decipher_v", **plot_kwargs) gif_maker.add_image(plt.gcf()) if is_notebook(): - from IPython.core import display + from IPython.display import clear_output, display display.clear_output(wait=True) display.display(plt.gcf()) @@ -232,9 +235,9 @@ def decipher_train( plt.close() if is_notebook(): - from IPython.core import display + from IPython.display import clear_output - display.clear_output() + clear_output() pbar.display() if early_stopping.has_stopped(): @@ -347,7 +350,9 @@ def score_rotation(r): if auto_flip_decipher_z: # flip each z to be correlated positively with the components dim_z = adata.obsm["decipher_z"].shape[1] - z_v_corr = np.corrcoef(adata.obsm["decipher_z"], y=adata.obsm["decipher_v"], rowvar=False) + z_v_corr = np.corrcoef( + adata.obsm["decipher_z"], y=adata.obsm["decipher_v"], rowvar=False + ) z_sign_correction = np.sign(z_v_corr[:dim_z, dim_z:].sum(axis=1)) adata.obsm["decipher_z_not_rotated"] = adata.obsm["decipher_z"].copy() adata.obsm["decipher_z"] = adata.obsm["decipher_z"] * z_sign_correction diff --git a/decipher/utils.py b/decipher/utils.py index e75e705..3e637cc 100644 --- a/decipher/utils.py +++ b/decipher/utils.py @@ -57,8 +57,8 @@ def add_image(self, fig): """ fig.set_dpi(self.dpi) fig.canvas.draw() - image = np.frombuffer(fig.canvas.tostring_rgb(), dtype="uint8") - image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + image = np.frombuffer(fig.canvas.buffer_rgba(), dtype="uint8") + image = image.reshape(fig.canvas.get_width_height()[::-1] + (4,)) self.images.append(Image.fromarray(image)) def save_gif(self, filename):