Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions decipher/tools/_basis_decomposition/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pyro.infer
import pyro.optim
import torch
from pyro.infer import Predictive, SVI, Trace_ELBO
from pyro.infer import SVI, Predictive, Trace_ELBO
from tqdm import tqdm

from decipher.tools._basis_decomposition.inference import get_inference_guide
Expand Down Expand Up @@ -65,13 +65,13 @@ def compute_basis_decomposition(
break

if plot_every_k_epochs > 0 and epoch % plot_every_k_epochs == 0:
from IPython.core import display
from IPython.display import clear_output, display

basis = model._last_basis.detach().numpy()
basis = model._last_basis.detach().cpu().numpy()
plt.figure(figsize=(5, 2.5))
_plot_basis(basis)
display.clear_output(wait=True)
display.display(plt.gcf())
clear_output(wait=True)
display(plt.gcf())
plt.close()

model.return_basis = False
Expand Down
19 changes: 13 additions & 6 deletions decipher/tools/_decipher/decipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def model(self, x, context=None):
if self.config.prior == "normal":
prior = dist.Normal(0, x.new_ones(self.config.dim_v)).to_event(1)
elif self.config.prior == "gamma":
prior = dist.Gamma(0.3, x.new_ones(self.config.dim_v) * 0.8).to_event(1)
prior = dist.Gamma(
0.3, x.new_ones(self.config.dim_v) * 0.8
).to_event(1)
else:
raise ValueError("Invalid prior, must be normal or gamma")
v = pyro.sample("v", prior)
Expand All @@ -131,7 +133,9 @@ def model(self, x, context=None):
self.theta + self._epsilon
)
# noinspection PyUnresolvedReferences
x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit)
x_dist = dist.NegativeBinomial(
total_count=self.theta + self._epsilon, logits=logit
)
pyro.sample("x", x_dist.to_event(1), obs=x)

def guide(self, x, context=None):
Expand All @@ -150,7 +154,10 @@ def guide(self, x, context=None):
with poutine.scale(scale=self.config.beta):
if self.config.prior == "gamma":
posterior_v = dist.Gamma(softplus(v_loc), v_scale).to_event(1)
elif self.config.prior == "normal" or self.config.prior == "student-normal":
elif (
self.config.prior == "normal"
or self.config.prior == "student-normal"
):
posterior_v = dist.Normal(v_loc, v_scale).to_event(1)
else:
raise ValueError("Invalid prior, must be normal or gamma")
Expand All @@ -173,17 +180,17 @@ def compute_v_z_numpy(self, x: np.array):
Decipher latent z of shape (n_cells, dim_z).
"""
if type(x) == np.ndarray:
x = torch.tensor(x, dtype=torch.float32)
x = torch.tensor(x, dtype=torch.float32, device=self.device)

x = torch.log1p(x)
z_loc, _ = self.encoder_x_to_z(x)
zx = torch.cat([z_loc, x], dim=-1)
v_loc, _ = self.encoder_zx_to_v(zx)
return v_loc.detach().numpy(), z_loc.detach().numpy()
return v_loc.detach().cpu().numpy(), z_loc.detach().cpu().numpy()

def impute_gene_expression_numpy(self, x):
if type(x) == np.ndarray:
x = torch.tensor(x, dtype=torch.float32)
x = torch.tensor(x, dtype=torch.float32, device=self.device)
z_loc, _, _, _ = self.guide(x)
mu = self.decoder_z_to_x(z_loc)
mu = softmax(mu, dim=-1)
Expand Down
21 changes: 13 additions & 8 deletions decipher/tools/decipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import numpy as np
import pyro
import pyro.optim
import scanpy as sc
import scipy
import torch
from matplotlib import pyplot as plt
import pyro.optim
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from tqdm import tqdm
Expand All @@ -17,8 +17,8 @@
from decipher.tools._decipher.data import (
decipher_load_model,
decipher_save_model,
make_data_loader_from_adata,
get_dense_X,
make_data_loader_from_adata,
)
from decipher.tools.utils import EarlyStopping
from decipher.utils import DECIPHER_GLOBALS, GIFMaker, is_notebook, load_and_show_gif
Expand Down Expand Up @@ -58,7 +58,9 @@ def predictive_log_likelihood(decipher, dataloader, n_samples=5):
model_trace = poutine.trace(
poutine.replay(decipher.model, trace=guide_trace)
).get_trace(*xc)
total_log_prob += model_trace.log_prob_sum() - guide_trace.log_prob_sum()
total_log_prob += (
model_trace.log_prob_sum() - guide_trace.log_prob_sum()
)
log_weights.append(total_log_prob)

finally:
Expand Down Expand Up @@ -206,7 +208,8 @@ def decipher_train(

decipher.eval()
val_nll = (
-predictive_log_likelihood(decipher, dataloader_val, n_samples=5) / adata_val.shape[0]
-predictive_log_likelihood(decipher, dataloader_val, n_samples=5)
/ adata_val.shape[0]
)
val_losses.append(val_nll)
pbar.set_description(
Expand All @@ -223,7 +226,7 @@ def decipher_train(
plot_decipher_v(adata, basis="decipher_v", **plot_kwargs)
gif_maker.add_image(plt.gcf())
if is_notebook():
from IPython.core import display
from IPython.display import clear_output, display

display.clear_output(wait=True)
display.display(plt.gcf())
Expand All @@ -232,9 +235,9 @@ def decipher_train(
plt.close()

if is_notebook():
from IPython.core import display
from IPython.display import clear_output

display.clear_output()
clear_output()
pbar.display()

if early_stopping.has_stopped():
Expand Down Expand Up @@ -347,7 +350,9 @@ def score_rotation(r):
if auto_flip_decipher_z:
# flip each z to be correlated positively with the components
dim_z = adata.obsm["decipher_z"].shape[1]
z_v_corr = np.corrcoef(adata.obsm["decipher_z"], y=adata.obsm["decipher_v"], rowvar=False)
z_v_corr = np.corrcoef(
adata.obsm["decipher_z"], y=adata.obsm["decipher_v"], rowvar=False
)
z_sign_correction = np.sign(z_v_corr[:dim_z, dim_z:].sum(axis=1))
adata.obsm["decipher_z_not_rotated"] = adata.obsm["decipher_z"].copy()
adata.obsm["decipher_z"] = adata.obsm["decipher_z"] * z_sign_correction
Expand Down
4 changes: 2 additions & 2 deletions decipher/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def add_image(self, fig):
"""
fig.set_dpi(self.dpi)
fig.canvas.draw()
image = np.frombuffer(fig.canvas.tostring_rgb(), dtype="uint8")
image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
image = np.frombuffer(fig.canvas.buffer_rgba(), dtype="uint8")
image = image.reshape(fig.canvas.get_width_height()[::-1] + (4,))
self.images.append(Image.fromarray(image))

def save_gif(self, filename):
Expand Down