Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions src/methods/scprint/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ info:
model_name: "small"

arguments:
- name: "--model_name"
- name: --model_name
type: "string"
description: Which model to use. Not used if --model is provided.
choices: ["large", "v2-medium", "small"]
Expand All @@ -48,18 +48,18 @@ arguments:
type: file
description: Path to the scPRINT model.
required: false
- name: max_len
- name: --max_len
type: integer
description: Maximum number of genes to consider.
default: 5000
- name: batch_size
default: 12000
- name: --batch_size
type: integer
description: Batch size for processing.
default: 32
- name: predict_depth_mult
- name: --predict_depth_mult
type: double
description: Multiplier for prediction depth.
default: 5.0
default: 1.0

resources:
- type: python_script
Expand All @@ -71,9 +71,7 @@ engines:
setup:
- type: python
pip:
- git+https://github.com/cantinilab/scPRINT.git@d8cc270b099c8d5dacf6913acc26f2b696685b2b
- gseapy==1.1.2
- git+https://github.com/jkobject/scDataLoader.git@c67c24a2e5c62399912be39169aae76e29e108aa
- scprint==2.3.5
- type: docker
run: lamin init --storage ./main --name main --schema bionty
- type: docker
Expand All @@ -85,6 +83,7 @@ engines:

runners:
- type: executable
# docker_run_args: --gpus all
- type: nextflow
directives:
label: [midtime, midmem, midcpu, gpu]
52 changes: 38 additions & 14 deletions src/methods/scprint/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
import torch
from huggingface_hub import hf_hub_download
from scdataloader import Preprocessor
from scdataloader.utils import load_genes
from scprint import scPrint
from scprint.tasks import Denoiser

## VIASH START
par = {
"input_train": "resources_test/task_batch_integration/cxg_immune_cell_atlas/train.h5ad",
"output": "output.h5ad",
"model_name": "large",
"model_name": "v2-medium",
"model": None,
"predict_depth_mult": 5.0,
"max_len": 5000,
"predict_depth_mult": 1.0,
"max_len": 12000,
"batch_size": 32,
}
meta = {"name": "scprint"}
Expand Down Expand Up @@ -61,39 +62,60 @@
)
print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)

print("\n>>> Denoising data...", flush=True)
if torch.cuda.is_available():
print("CUDA is available, using GPU", flush=True)
precision = "16-mixed"
dtype = torch.float16
transformer = "flash"
else:
print("CUDA is not available, using CPU", flush=True)
precision = "32"
dtype = torch.float32
transformer = "normal"

m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))
# make sure that you check if you have a GPU with flashattention or not (see README)
try:
m = torch.load(model_checkpoint_file)
    # if not, use this fallback instead, since the model weights are mapped to GPU device types by default
except RuntimeError:
m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))

# both are for compatibility issues with different versions of the pretrained model, so we need to load it with the correct transformer
if "prenorm" in m["hyper_parameters"]:
m["hyper_parameters"].pop("prenorm")
torch.save(m, model_checkpoint_file)
if "label_counts" in m["hyper_parameters"]:
    # you need to set precpt_gene_emb=None, otherwise the model will look for its precomputed gene embedding files even though they were already converted into model weights — that file is not needed for a pretrained model
model = scPrint.load_from_checkpoint(
model_checkpoint_file,
transformer=transformer, # Don't use this for GPUs with flashattention
precpt_gene_emb=None,
classes=m["hyper_parameters"]["label_counts"],
transformer=transformer,
)
else:
model = scPrint.load_from_checkpoint(
model_checkpoint_file,
transformer=transformer, # Don't use this for GPUs with flashattention
precpt_gene_emb=None,
model_checkpoint_file, precpt_gene_emb=None, transformer=transformer
)
del m
# this might happen if the model was trained with a different set of genes than the ones in the ontology you are using (e.g. newer ontologies). While having genes in the ontology that are not in the model is fine, the opposite is not, so we need to remove the genes that are in the model but not in the ontology
missing = set(model.genes) - set(load_genes(model.organisms).index)
if len(missing) > 0:
print(
    "Warning: some gene mismatches exist between model and ontology: solving...",
)
model._rm_genes(missing)

# again, if not on a GPU you need to convert the model to float32
if not torch.cuda.is_available():
model = model.to(torch.float32)

# you can perform your inference in float16 if you have a GPU, otherwise use float32
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# the models are often loaded with some parts still displayed as "cuda" and some as "cpu", so we need to make sure that the model is fully on the right device
model = model.to("cuda" if torch.cuda.is_available() else "cpu")


n_cores = min(len(os.sched_getaffinity(0)), 24)
print(f"Using {n_cores} worker cores")
denoiser = Denoiser(
num_workers=n_cores,
precision=precision,
max_cells=adata.n_obs + 1000,
max_len=par["max_len"],
batch_size=par["batch_size"],
Expand All @@ -103,6 +125,8 @@
dtype=dtype,
how="most var",
)

print("\n>>> Denoising data...", flush=True)
_, idxs, output = denoiser(model, adata)
print(f"Predicted expression dimensions: {output.shape}")

Expand Down
23 changes: 11 additions & 12 deletions src/metrics/mse/script.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
import anndata as ad
import scanpy as sc
import sklearn.metrics
import scprep
import sklearn.metrics

## VIASH START
par = {
'input_test': 'resources_test/task_denoising/cxg_immune_cell_atlas/test.h5ad',
'input_prediction': 'resources_test/task_denoising/cxg_immune_cell_atlas/denoised.h5ad',
'output': 'output_mse.h5ad'
}
meta = {
'name': 'mse'
"input_test": "resources_test/task_denoising/cxg_immune_cell_atlas/test.h5ad",
"input_prediction": "resources_test/task_denoising/cxg_immune_cell_atlas/denoised.h5ad",
"output": "output_mse.h5ad",
}
meta = {"name": "mse"}
## VIASH END

print("Load data", flush=True)
input_denoised = ad.read_h5ad(par['input_prediction'])
input_test = ad.read_h5ad(par['input_test'])
input_denoised = ad.read_h5ad(par["input_prediction"])
input_test = ad.read_h5ad(par["input_test"])

test_data = ad.AnnData(X=input_test.layers["counts"])
denoised_data = ad.AnnData(X=input_denoised.layers["denoised"])
Expand All @@ -39,12 +37,13 @@

print("Store mse value", flush=True)
output = ad.AnnData(
uns={ key: val for key, val in input_test.uns.items() },
uns={key: val for key, val in input_test.uns.items()},
)

output.uns["method_id"] = input_denoised.uns["method_id"]
output.uns["metric_ids"] = meta['name']
output.uns["metric_ids"] = meta["name"]
output.uns["metric_values"] = error

print("Write adata to file", flush=True)
output.write_h5ad(par['output'], compression="gzip")
print(output.uns)
output.write_h5ad(par["output"], compression="gzip")