diff --git a/.gitignore b/.gitignore index c72956b..f213930 100644 --- a/.gitignore +++ b/.gitignore @@ -189,3 +189,4 @@ docs/wandb run_tutorial.s* logs/ +*.sbatch diff --git a/src/decima/data/dataset.py b/src/decima/data/dataset.py index f73df53..3fad18d 100644 --- a/src/decima/data/dataset.py +++ b/src/decima/data/dataset.py @@ -868,6 +868,8 @@ def __len__(self): def validate_allele_seq(self, gene, variant): seq = self.result.gene_sequence(gene, genome=self.genome) pos = variant.rel_pos + if variant.strand == "-": + pos = pos - len(variant.ref) + 1 ref_match = seq[pos : pos + len(variant.ref)] == variant.ref_tx alt_match = seq[pos : pos + len(variant.alt)] == variant.alt_tx return ref_match, alt_match @@ -889,6 +891,9 @@ def __getitem__(self, idx): variant = self.variants.iloc[seq_idx] rel_pos = variant.rel_pos + self.max_seq_shift + if variant.strand == "-": + rel_pos = rel_pos - len(variant.ref) + 1 + # by default cache values are nan if matched with reference genome # then it will be replaced with the predicted expression from cache. pred_expr = {model_name: torch.full((self.result.shape[0],), torch.nan) for model_name in self.model_names} diff --git a/src/decima/utils/dataframe.py b/src/decima/utils/dataframe.py index 37a1dad..663fd46 100644 --- a/src/decima/utils/dataframe.py +++ b/src/decima/utils/dataframe.py @@ -50,6 +50,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.writer.close() else: warnings.warn("NoDataFrameWrittenError: No dataframe was written to the parquet file.") + pd.DataFrame({}).to_parquet(self.output_path) self.first_chunk = True def write(self, chunk: pd.DataFrame) -> None: diff --git a/src/decima/vep/attributions.py b/src/decima/vep/attributions.py index ebbd531..81436d3 100644 --- a/src/decima/vep/attributions.py +++ b/src/decima/vep/attributions.py @@ -31,6 +31,7 @@ from decima.utils.io import read_vcf_chunks, VariantAttributionWriter from decima.core.result import DecimaResult from decima.data.dataset import VariantDataset +from decima.hub import load_decima_model from decima.interpret.attributer import DecimaAttributer from decima.model.metrics import WarningCounter from decima.vep.vep import _log_vep_warnings, _write_vep_warnings @@ -158,23 +159,22 @@ def variant_effect_attribution( f"Unsupported input type: {type(variants)}. Must be pd.DataFrame or str (path to .tsv or .vcf)." ) - result = DecimaResult.load(metadata_anndata) - + model = load_decima_model(model, device=device) + result = DecimaResult.load(metadata_anndata or model.name) tasks, off_tasks = _get_on_off_tasks(result, tasks, off_tasks) - attributer = DecimaAttributer.load_decima_attributer( - model_name=model, + attributer = DecimaAttributer( + model=model, tasks=tasks, off_tasks=off_tasks, method=method, transform=transform, - device=device, ) warning_counter = WarningCounter() dataset = VariantDataset( variants, - metadata_anndata=metadata_anndata, + metadata_anndata=result, gene_col=gene_col, distance_type=distance_type, min_distance=min_distance, diff --git a/tests/test_vep.py b/tests/test_vep.py index 9109fb1..25511fa 100644 --- a/tests/test_vep.py +++ b/tests/test_vep.py @@ -82,6 +82,17 @@ def test_VariantDataset_overlap_genes(df_variant): }) df = VariantDataset.overlap_genes(df_variant, df_genes) +def test_VariantDataset_validate_allele_seq(): + df_variant = pd.DataFrame({ + "chrom": ["chr15"], + "pos": [44715509], + "ref": ["CC"], + "alt": ["TT"] + }) + dataset = VariantDataset(df_variant) + ref_match, _ = dataset.validate_allele_seq("SPG11", dataset.variants.iloc[1]) + assert ref_match + def test_VariantDataset(df_variant): dataset = VariantDataset(df_variant, model_name="v1_rep0")