19,530 changes: 7,408 additions & 12,122 deletions pixi.lock

Large diffs are not rendered by default.

43 changes: 22 additions & 21 deletions pixi.toml
@@ -3,6 +3,9 @@ name = "genvarloader"
channels = ["conda-forge", "bioconda"]
platforms = ["linux-64"]

[activation.env]
LD_LIBRARY_PATH = "$CONDA_PREFIX/lib"

[environments]
dev = { features = ["pytorch-cpu", "basenji2", "py310"] }
docs = { features = ["docs", "pytorch-cpu", "basenji2", "py312"] }
@@ -14,16 +17,31 @@ no-torch = { features = ["py310"] }
demo = { features = ["demo", "py310"] }

[dependencies]
uv = "*"
samtools = "*"
bcftools = "*"
plink2 = "*"
ruff = "*"
pre-commit = "*"
commitizen = "*"
pyarrow = "<22"
maturin = ">=1.6,<2"
typing-extensions = ">=4.14"

[pypi-dependencies]
genvarloader = { path = ".", editable = true }
hirola = "==0.3"
seqpro = "==0.9.0"
genoray = "==1.0.1"
numba = ">=0.58.1"
polars = "==1.26.0"
polars = "==1.30.0"
loguru = "*"
attrs = "*"
natsort = "*"
cyvcf2 = "*"
pgenlib = "*"
pandera = "*"
pysam = "*"
pyarrow = "*"
pyranges = "*"
more-itertools = "*"
tqdm = "*"
@@ -33,32 +51,18 @@ tbb = "*"
joblib = "*"
pooch = "*"
awkward = "*"
maturin = ">=1.6,<2"
pytest = "*"
memray = "*"
py-spy = "*"
icecream = "*"
pydantic = ">=2,<3"
pytest-cases = "*"
pytest-cov = "*"
ruff = "*"
pre-commit = "*"
pytest-benchmark = "*"
hypothesis = "*"
filelock = "*"
patchelf = "*"
commitizen = "*"
typer = "*"
uv = "*"
samtools = "*"
bcftools = "*"
plink2 = "*"

[pypi-dependencies]
# genvarloader = { path = ".", editable = true }
hirola = "==0.3"
seqpro = "==0.8.2"
genoray = "==0.16.0"
pydantic = ">=2,<3"

[feature.docs.dependencies]
sphinx = ">=7.4.7"
@@ -92,11 +96,8 @@ python = "3.11.*"
python = "3.12.*"

[tasks]
install = "uv pip install -e ."
pre-commit = "pre-commit install --hook-type commit-msg --hook-type pre-push"
gen = { cmd = "python tests/data/generate_ground_truth.py", depends-on = [
"install",
] }
gen = { cmd = "python tests/data/generate_ground_truth.py" }
test = { cmd = "pytest tests && cargo test --release", depends-on = ["gen"] }

[feature.docs.tasks]
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -14,7 +14,7 @@ dependencies = [
"loguru",
"attrs",
"natsort",
"polars>=1.26",
"polars>=1.30",
"cyvcf2",
"pandera",
"pysam",
@@ -30,8 +30,8 @@ dependencies = [
"pooch",
"awkward",
"hirola>=0.3,<0.4",
"seqpro>=0.8.2",
"genoray>=0.16.0",
"seqpro>=0.9",
"genoray>=1.0.1,<2",
]

[project.urls]
4 changes: 2 additions & 2 deletions python/genvarloader/_dataset/_impl.py
@@ -1453,9 +1453,9 @@ def _rc(
) -> Ragged | RaggedAnnotatedHaps | RaggedVariants | RaggedIntervals:
if isinstance(rag, Ragged):
if is_rag_dtype(rag, np.bytes_):
rag = Ragged(ak.where(to_rc, reverse_complement(rag), rag))
rag = Ragged(ak.where(to_rc, reverse_complement(rag), rag)) # type: ignore
else:
rag = Ragged(ak.where(to_rc, rag[..., ::-1], rag))
rag = Ragged(ak.where(to_rc, rag[..., ::-1], rag)) # type: ignore
elif isinstance(rag, RaggedAnnotatedHaps):
rag.haps = self._rc(rag.haps, to_rc)
rag.var_idxs = self._rc(rag.var_idxs, to_rc)
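
The rewritten `_rc` relies on `ak.where` broadcasting a per-region boolean mask against ragged data instead of assigning into a slice. A minimal sketch of that pattern on made-up data, with a stand-in reverse complement (the real one comes from `seqpro`):

```python
import awkward as ak
import numpy as np

# two regions of different lengths, stored as ASCII codes like Ragged[np.bytes_]
seqs = ak.Array([
    np.frombuffer(b"ACGT", dtype=np.uint8),
    np.frombuffer(b"GG", dtype=np.uint8),
])
to_rc = np.array([True, False])  # e.g. regions on the negative strand

# stand-in reverse complement: map A<->T and C<->G, then reverse each row
lut = np.zeros(256, np.uint8)
lut[[ord("A"), ord("C"), ord("G"), ord("T")]] = [ord("T"), ord("G"), ord("C"), ord("A")]
rc = ak.Array([lut[ak.to_numpy(row)][::-1] for row in seqs])

# rows where to_rc is True take the reverse-complemented version; layout stays ragged
out = ak.where(to_rc, rc, seqs)
```
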
15 changes: 7 additions & 8 deletions python/genvarloader/_dataset/_rag_variants.py
@@ -13,7 +13,7 @@
NumpyArray,
RegularArray,
)
from genoray._svar import DOSAGE_TYPE, POS_TYPE, V_IDX_TYPE
from genoray._types import DOSAGE_TYPE, POS_TYPE, V_IDX_TYPE
from numpy.typing import NDArray
from seqpro.rag import OFFSET_TYPE, Ragged, lengths_to_offsets
from typing_extensions import Self
@@ -24,6 +24,7 @@
if TORCH_AVAILABLE or TYPE_CHECKING:
import torch
from torch.nested import nested_tensor_from_jagged as nt_jag
from torch.nested._internal.nested_tensor import NestedTensor


class RaggedVariant(ak.Record):
@@ -234,7 +235,7 @@ def to_nested_tensor_batch(
tokenizer: Literal["seqpro"]
| Callable[[NDArray[np.bytes_]], NDArray[np.integer]]
| None = None,
) -> dict[str, Any]:
) -> dict[str, NestedTensor | int]:
"""Convert a RaggedVariants object to a dictionary of nested tensors. Will flatten across
the ploidy dimension for attributes ILEN, starts, and dosages such that their shapes are (batch * ploidy, ~variants).
For the alternative alleles, will flatten across both the ploidy and variant dimensions such that the
@@ -278,13 +279,11 @@
)
batch["var_maxlen"] = int(np.diff(arr.offsets).max())
batch[field] = nt_jag(data, variant_offsets)
else:
elif field in {"ref", "alt"}:
data, offsets, max_alen = _alleles_to_nested_tensor(arr, tokenizer)
if field == "alt":
batch["alt_maxlen"] = max_alen
elif field == "ref":
batch["ref_maxlen"] = max_alen
batch[field] = nt_jag(data, offsets).to(device)
data = data.to(device)
batch[f"{field}_maxlen"] = max_alen
batch[field] = nt_jag(data, offsets)

return batch

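
For context on the `to_nested_tensor_batch` changes above: each ragged field is handed to `torch.nested.nested_tensor_from_jagged` as a flat values buffer plus offsets. A minimal sketch of that conversion on made-up data, assuming a recent PyTorch with the jagged nested-tensor layout:

```python
import numpy as np
import torch
from torch.nested import nested_tensor_from_jagged

# hypothetical variant counts for a batch of 2 samples with ploidy 2 -> 4 rows
lengths = np.array([3, 1, 0, 2])
offsets = np.concatenate([[0], np.cumsum(lengths)])  # (batch * ploidy + 1,)
values = np.arange(offsets[-1], dtype=np.int64)      # flattened per-variant data

nt = nested_tensor_from_jagged(
    torch.from_numpy(values),
    offsets=torch.from_numpy(offsets),
)
print(nt.shape)      # 4 rows with a ragged second dimension
print(nt[0], nt[2])  # rows of length 3 and 0 respectively
```
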
13 changes: 4 additions & 9 deletions python/genvarloader/_dataset/_reconstruct.py
@@ -14,13 +14,8 @@
from awkward.contents import ListOffsetArray, NumpyArray, RegularArray
from awkward.index import Index
from einops import repeat
from genoray._svar import (
DOSAGE_TYPE,
POS_TYPE,
V_IDX_TYPE,
SparseDosages,
SparseGenotypes,
)
from genoray._svar import SparseDosages, SparseGenotypes
from genoray._types import DOSAGE_TYPE, POS_TYPE, V_IDX_TYPE
from loguru import logger
from numpy.typing import NDArray
from packaging.version import Version
@@ -489,7 +484,7 @@ def _get_variants(
{
k: self._get_info(genos, k)
for k in self.var_fields
if k not in {"alt", "ilen", "start", "ref"}
if k not in {"alt", "start", "ref", "ilen", "dosage"}
}
)

@@ -832,7 +827,7 @@ def _call_intervals(self, idx: NDArray[np.integer]) -> RaggedIntervals:
starts = ak.concatenate(out_starts, axis=1)
ends = ak.concatenate(out_ends, axis=1)
values = ak.concatenate(out_values, axis=1)
return RaggedIntervals(starts, ends, values)
return RaggedIntervals(starts, ends, values) # type: ignore

def write_transformed_track(
self,
11 changes: 5 additions & 6 deletions python/genvarloader/_dataset/_reference.py
@@ -366,21 +366,20 @@ def __getitem__(self, idx: Idx) -> T:

to_rc = regions[:, 3] == -1
if to_rc.any():
rc_ref = reverse_complement(cast(Ragged[np.bytes_], ref[to_rc]))
ref = ak.where(to_rc, rc_ref, ref)
ref = ak.where(to_rc, reverse_complement(ref), ref)

if out_reshape is not None:
ref = ref.reshape(out_reshape)
ref = ref.reshape(out_reshape) # type: ignore

if self.output_length == "ragged":
out = ref
elif self.output_length == "variable":
out = to_padded(ref, pad_value=self.reference.pad_char)
out = to_padded(ref, pad_value=self.reference.pad_char) # type: ignore
else:
out = ref.to_numpy()
out = ref.to_numpy() # type: ignore

if squeeze:
out = out.squeeze(0)
out = out.squeeze(0) # type: ignore

return cast(T, out)

6 changes: 4 additions & 2 deletions python/genvarloader/_dataset/_utils.py
@@ -111,12 +111,14 @@ def bed_to_regions(bed: pl.DataFrame, contigs: Sequence[str]) -> NDArray[np.int3
pl.col("chromStart", "chromEnd").cast(pl.Int32),
]

if "strand" in bed:
if bed.schema.get("strand", None) == pl.Utf8:
cols.append(
pl.col("strand").replace_strict({"+": 1, "-": -1}, return_dtype=pl.Int32)
)
else:
elif "strand" not in bed.schema:
cols.append(pl.lit(1).cast(pl.Int32).alias("strand"))
else:
cols.append(pl.col("strand"))

return bed.select(cols).to_numpy()

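
The new `bed_to_regions` branch only remaps the strand column when it is still a string dtype. A small sketch of that mapping on a toy BED frame (column names follow the function above; the data is made up):

```python
import polars as pl

bed = pl.DataFrame({
    "chrom": ["chr1", "chr1"],
    "chromStart": [10, 50],
    "chromEnd": [20, 60],
    "strand": ["+", "-"],
})

# Utf8 strand -> the +1/-1 encoding stored in the regions array
encoded = bed.with_columns(
    pl.col("strand").replace_strict({"+": 1, "-": -1}, return_dtype=pl.Int32)
)
print(encoded["strand"].to_list())  # [1, -1]

# If "strand" were absent it would be filled with a literal 1, and an already
# numeric strand column is now passed through unchanged.
```
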
34 changes: 17 additions & 17 deletions python/genvarloader/_dataset/_write.py
@@ -11,22 +11,22 @@
import polars as pl
import seqpro as sp
from genoray import PGEN, VCF, Reader, SparseVar
from genoray._svar import V_IDX_TYPE, SparseGenotypes, dense2sparse
from genoray._svar import SparseGenotypes, dense2sparse
from genoray._types import V_IDX_TYPE
from genoray._utils import parse_memory
from loguru import logger
from more_itertools import mark_ends
from natsort import natsorted
from numpy.typing import NDArray
from packaging.version import Version
from pydantic import BaseModel, BeforeValidator, PlainSerializer, WithJsonSchema
from seqpro.rag import OFFSET_TYPE
from tqdm.auto import tqdm

from .._bigwig import BigWigs
from .._ragged import INTERVAL_DTYPE
from .._utils import lengths_to_offsets, normalize_contig_name
from .._variants._utils import path_is_pgen, path_is_vcf
from ._utils import splits_sum_le_value
from ._utils import bed_to_regions, splits_sum_le_value


class Metadata(BaseModel, arbitrary_types_allowed=True):
@@ -237,12 +237,7 @@ def _prep_bed(


def _write_regions(path: Path, bed: pl.DataFrame, contigs: list[str]):
with pl.StringCache():
pl.Series(contigs, dtype=pl.Categorical)
regions = bed.with_columns(
pl.col("chrom").cast(pl.Categorical).to_physical()
).with_columns(pl.all().cast(pl.Int32))
regions = regions.to_numpy()
regions = bed_to_regions(bed, contigs)
np.save(path / "regions.npy", regions)


@@ -318,7 +313,10 @@ def _write_from_vcf(path: Path, bed: pl.DataFrame, vcf: VCF, max_mem: int):
if is_last:
max_ends.append(chunk_end)

var_idxs = ak.flatten(ak.concatenate(ls_sparse, -1), None).to_numpy()
var_idxs = ak.flatten(
ak.concatenate(ls_sparse, -1),
None, # type: ignore
).to_numpy()
# (s p)
lengths = np.stack([a.lengths for a in ls_sparse], 0).sum(0)

@@ -398,7 +396,10 @@ def _write_from_pgen(path: Path, bed: pl.DataFrame, pgen: PGEN, max_mem: int):
if is_last:
max_ends.append(chunk_end)

var_idxs = ak.flatten(ak.concatenate(ls_sparse, -1), None).to_numpy()
var_idxs = ak.flatten(
ak.concatenate(ls_sparse, -1),
None, # type: ignore
).to_numpy()
# (s p)
lengths = np.stack([a.lengths for a in ls_sparse], 0).sum(0)

@@ -458,7 +459,9 @@ def _write_from_svar(
with open(out_dir / "svar_meta.json", "w") as f:
json.dump({"shape": offsets.shape, "dtype": offsets.dtype.str}, f)

v_ends = svar.granges.End
v_ends = svar.var_table.select(
end=pl.col("POS") - pl.col("ILEN").list.first().clip(upper_bound=0)
)["end"].to_numpy()
max_ends = np.empty(bed.height, np.int32)
contig_offset = 0
pbar = tqdm(total=bed.height, unit=" region")
@@ -477,10 +480,7 @@
c, df["chromStart"], df["chromEnd"], samples=samples, out=out
)

if (
first_no_variant_warning
and (out == np.iinfo(OFFSET_TYPE).max).all((1, 2, 3)).any()
):
if first_no_variant_warning and (out == 0).all((1, 2, 3)).any():
first_no_variant_warning = False
logger.warning(
"Some regions have no variants for any sample. This could be expected depending on the region lengths"
@@ -511,7 +511,7 @@
pbar.close()
offsets.flush()

(out_dir / "link.svar").symlink_to(svar.path, True)
(out_dir / "link.svar").symlink_to(svar.path.resolve(), target_is_directory=True)

return bed.with_columns(chromEnd=pl.Series(max_ends))

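
The `v_ends` expression added in `_write_from_svar` derives each variant's reference end as `POS - min(first ILEN, 0)`, so only deletions (negative ILEN) push the end past POS. A toy check of that arithmetic (the ILEN values are made up; `ILEN` is a list column with one entry per alt allele):

```python
import polars as pl

var_table = pl.DataFrame({
    "POS": [100, 200, 300],
    "ILEN": [[0], [-3], [2]],  # SNP, 3 bp deletion, 2 bp insertion
})

v_ends = var_table.select(
    end=pl.col("POS") - pl.col("ILEN").list.first().clip(upper_bound=0)
)["end"].to_numpy()

print(v_ends)  # [100 203 300] -> only the deletion extends the end beyond POS
```
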
3 changes: 3 additions & 0 deletions tests/dataset/test_ds_haps.py
@@ -18,6 +18,7 @@ def dataset_vcf():
gvl.Dataset.open(data_dir / "phased_dataset.vcf.gvl", ref, rc_neg=False)
.with_len("ragged")
.with_seqs("haplotypes")
.with_tracks(False)
)
return ds

@@ -27,6 +28,7 @@ def dataset_pgen():
gvl.Dataset.open(data_dir / "phased_dataset.pgen.gvl", ref, rc_neg=False)
.with_len("ragged")
.with_seqs("haplotypes")
.with_tracks(False)
)
return ds

@@ -36,6 +38,7 @@ def dataset_svar():
gvl.Dataset.open(data_dir / "phased_dataset.svar.gvl", ref, rc_neg=False)
.with_len("ragged")
.with_seqs("haplotypes")
.with_tracks(False)
)
return ds
