Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pixi.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[project]
[workspace]
name = "genvarloader"
channels = ["conda-forge", "bioconda"]
platforms = ["linux-64"]
Expand Down
121 changes: 83 additions & 38 deletions python/genvarloader/_dataset/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .._torch import TORCH_AVAILABLE, TorchDataset, get_dataloader
from .._types import AnnotatedHaps, Idx, StrIdx
from .._utils import lengths_to_offsets, normalize_contig_name
from ._indexing import DatasetIndexer
from ._indexing import DatasetIndexer, is_str_arr
from ._rag_variants import RaggedVariants
from ._reconstruct import Haps, HapsTracks, Ref, RefTracks, Tracks, TrackType
from ._reference import Reference
Expand Down Expand Up @@ -91,6 +91,9 @@ def open(
rng: int | np.random.Generator | None = False,
deterministic: bool = True,
rc_neg: bool = True,
min_af: float | None = None,
max_af: float | None = None,
region_names: str | None = None,
) -> RaggedDataset[MaybeRSEQ, MaybeRTRK]: ...
@overload
@staticmethod
Expand All @@ -101,6 +104,9 @@ def open(
rng: int | np.random.Generator | None = False,
deterministic: bool = True,
rc_neg: bool = True,
min_af: float | None = None,
max_af: float | None = None,
region_names: str | None = None,
) -> RaggedDataset[RaggedSeqs, MaybeRTRK]: ...
@staticmethod
def open(
Expand All @@ -112,6 +118,7 @@ def open(
rc_neg: bool = True,
min_af: float | None = None,
max_af: float | None = None,
region_names: str | None = None,
) -> RaggedDataset[MaybeRSEQ, MaybeRTRK]:
"""Open a dataset from a path. If no reference genome is provided, the dataset cannot yield sequences.
Will initialize the dataset such that it will return tracks and haplotypes (reference sequences if no genotypes) if possible.
Expand Down Expand Up @@ -151,9 +158,16 @@ def open(

# read input regions and generate index map
bed = pl.read_ipc(path / "input_regions.arrow")
if region_names is not None:
_region_names = bed[region_names].to_list()
else:
_region_names = None
r_idx_map = bed["r_idx_map"].to_numpy().astype(np.intp)
idxer = DatasetIndexer.from_region_and_sample_idxs(
r_idx_map, np.arange(len(samples)), samples
r_idxs=r_idx_map,
s_idxs=np.arange(len(samples)),
samples=samples,
regions=_region_names,
)
bed = bed.drop("r_idx_map")
sorted_bed = sp.bed.sort(bed)
Expand Down Expand Up @@ -527,7 +541,13 @@ class AnnotatedHaps:
"Dataset has no reference genome to reconstruct sequences from."
)
seqs = Ref(reference=ref)
return evolve(self, _recon=RefTracks(seqs=seqs, tracks=tracks))
return evolve(
self,
_recon=RefTracks(
seqs=seqs,
tracks=tracks, # type: ignore
),
)

case "haplotypes", Haps() as haps, _, Ref() | Haps():
return evolve(self, _recon=haps.to_kind(RaggedSeqs))
Expand All @@ -536,7 +556,13 @@ class AnnotatedHaps:
| RefTracks(tracks=tracks)
| HapsTracks(tracks=tracks)
):
return evolve(self, _recon=HapsTracks(haps.to_kind(RaggedSeqs), tracks))
return evolve(
self,
_recon=HapsTracks(
haps.to_kind(RaggedSeqs),
tracks, # type: ignore
),
)

case "annotated", Haps() as haps, _, Ref() | Haps():
return evolve(self, _recon=haps.to_kind(RaggedAnnotatedHaps))
Expand All @@ -546,7 +572,11 @@ class AnnotatedHaps:
| HapsTracks(tracks=tracks)
):
return evolve(
self, _recon=HapsTracks(haps.to_kind(RaggedAnnotatedHaps), tracks)
self,
_recon=HapsTracks(
haps.to_kind(RaggedAnnotatedHaps),
tracks, # type: ignore
),
)

case "variants", Haps() as haps, _, Ref() | Haps():
Expand All @@ -557,7 +587,11 @@ class AnnotatedHaps:
| HapsTracks(tracks=tracks)
):
return evolve(
self, _recon=HapsTracks(haps.to_kind(RaggedVariants), tracks)
self,
_recon=HapsTracks(
haps.to_kind(RaggedVariants),
tracks, # type: ignore
),
)

case k, s, t, r:
Expand Down Expand Up @@ -613,15 +647,24 @@ def with_tracks(
"Can't set dataset to return tracks because it has none to begin with."
)
case t, _, tr, (Ref() as seqs) | RefTracks(seqs=seqs):
tr = tr.with_tracks(t).to_kind(_kind)
tr = tr.with_tracks(t).to_kind(
_kind, # type: ignore
)
recon = RefTracks(seqs=seqs, tracks=tr)
return evolve(self, _tracks=tr, _recon=recon)
case t, _, tr, (Haps() as seqs) | HapsTracks(haps=seqs):
tr = tr.with_tracks(t).to_kind(_kind)
recon = HapsTracks(haps=seqs, tracks=tr)
tr = tr.with_tracks(t).to_kind(
_kind, # type: ignore
)
recon = HapsTracks(
haps=seqs, # type: ignore
tracks=tr,
)
return evolve(self, _tracks=tr, _recon=recon)
case t, _, tr, Tracks():
tr = tr.with_tracks(t).to_kind(_kind)
tr = tr.with_tracks(t).to_kind(
_kind, # type: ignore
)
return evolve(self, _tracks=tr, _recon=tr)
case k, s, t, r:
assert_never(k), assert_never(s), assert_never(t), assert_never(r)
Expand Down Expand Up @@ -851,7 +894,7 @@ def __repr__(self) -> str:

def subset_to(
self,
regions: Idx | None = None,
regions: StrIdx | None = None,
samples: StrIdx | None = None,
) -> Self:
"""Subset the dataset to specific regions and/or samples by index or a boolean mask. If regions or samples
Expand Down Expand Up @@ -910,13 +953,10 @@ def subset_to(
if regions is None and samples is None:
return self

if regions is not None:
if isinstance(regions, pl.Series):
regions = regions.to_numpy()
if np.issubdtype(regions.dtype, np.bool_):
regions = np.nonzero(regions)[0]
elif not np.issubdtype(regions.dtype, np.integer):
raise ValueError("`regions` must be index-like or a boolean mask.")
if is_str_arr(regions) and self._idxer.r2i_map is None:
raise ValueError(
"Cannot subset to regions by name because no region name was set."
)

idxer = self._idxer.subset_to(regions=regions, samples=samples)

Expand Down Expand Up @@ -981,7 +1021,7 @@ def haplotype_lengths(
def n_variants(
self,
regions: Idx | None = None,
samples: Idx | str | Sequence[str] | None = None,
samples: StrIdx | None = None,
) -> NDArray[np.int32]:
"""The number of variants in the dataset for specified regions and samples.

Expand Down Expand Up @@ -1026,7 +1066,7 @@ def n_variants(
def n_intervals(
self,
regions: Idx | None = None,
samples: Idx | str | Sequence[str] | None = None,
samples: StrIdx | None = None,
) -> NDArray[np.int32]:
"""The number of intervals in the dataset for specified regions and samples.

Expand Down Expand Up @@ -1315,7 +1355,7 @@ def to_dataloader(
)

def __getitem__(
self, idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]]
self, idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx]
) -> (
Ragged[np.bytes_ | np.float32]
| RaggedAnnotatedHaps
Expand All @@ -1333,6 +1373,11 @@ def __getitem__(
...,
]
):
if is_str_arr(idx) and self._idxer.r2i_map is None:
raise ValueError(
"Cannot query regions by name because no region name was set."
)

# (b)
ds_idx, squeeze, out_reshape = self._idxer.parse_idx(idx)
r_idx, _ = np.unravel_index(ds_idx, self.full_shape)
Expand Down Expand Up @@ -1579,17 +1624,17 @@ def with_tracks(
@overload
def __getitem__(
self: ArrayDataset[SEQ, None],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> SEQ: ...
@overload
def __getitem__(
self: ArrayDataset[None, NDArray[np.float32]],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> NDArray[np.float32]: ...
@overload
def __getitem__(
self: ArrayDataset[SEQ, NDArray[np.float32]],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> tuple[SEQ, NDArray[np.float32]]: ...
@overload
def __getitem__(
Expand All @@ -1614,15 +1659,15 @@ def __getitem__(
@overload
def __getitem__(
self: ArrayDataset[MaybeSEQ, RaggedIntervals],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> RaggedIntervals | tuple[SEQ, RaggedIntervals]: ...
@overload
def __getitem__(
self: ArrayDataset[MaybeSEQ, MaybeTRK],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> SEQ | NDArray[np.float32] | tuple[SEQ, NDArray[np.float32]]: ...
def __getitem__(
self, idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]]
self, idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx]
) -> SEQ | TRK | tuple[SEQ, TRK]:
return super().__getitem__(idx) # type: ignore

Expand Down Expand Up @@ -1729,54 +1774,54 @@ def with_tracks(
@overload
def __getitem__(
self: RaggedDataset[None, None],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> NoReturn: ...
@overload
def __getitem__(
self: RaggedDataset[RSEQ, None],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> RSEQ: ...
@overload
def __getitem__(
self: RaggedDataset[None, Ragged[np.float32]],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> Ragged[np.float32]: ...
@overload
def __getitem__(
self: RaggedDataset[RSEQ, Ragged[np.float32]],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> tuple[RSEQ, Ragged[np.float32]]: ...
@overload
def __getitem__(
self: RaggedDataset[None, RaggedIntervals],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> RaggedIntervals: ...
@overload
def __getitem__(
self: RaggedDataset[RSEQ, RaggedIntervals],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> tuple[RSEQ, RaggedIntervals]: ...
@overload
def __getitem__(
self: RaggedDataset[RSEQ, MaybeRTRK],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> RSEQ | tuple[RSEQ, Ragged[np.float32]]: ...
@overload
def __getitem__(
self: RaggedDataset[MaybeRSEQ, Ragged[np.float32]],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> Ragged[np.float32] | tuple[RSEQ, Ragged[np.float32]]: ...
@overload
def __getitem__(
self: RaggedDataset[MaybeRSEQ, RaggedIntervals],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> RaggedIntervals | tuple[RSEQ, RaggedIntervals]: ...
@overload
def __getitem__(
self: RaggedDataset[MaybeRSEQ, MaybeRTRK],
idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
) -> RSEQ | Ragged[np.float32] | tuple[RSEQ, Ragged[np.float32]]: ...
def __getitem__(
self, idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]]
self, idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx]
) -> RSEQ | RTRK | tuple[RSEQ, RTRK]:
return super().__getitem__(idx) # type: ignore
Loading