diff --git a/pixi.toml b/pixi.toml index 6ea9a80..d45d4e0 100644 --- a/pixi.toml +++ b/pixi.toml @@ -1,4 +1,4 @@ -[project] +[workspace] name = "genvarloader" channels = ["conda-forge", "bioconda"] platforms = ["linux-64"] diff --git a/python/genvarloader/_dataset/_impl.py b/python/genvarloader/_dataset/_impl.py index 92ebc16..137b42c 100644 --- a/python/genvarloader/_dataset/_impl.py +++ b/python/genvarloader/_dataset/_impl.py @@ -28,7 +28,7 @@ from .._torch import TORCH_AVAILABLE, TorchDataset, get_dataloader from .._types import AnnotatedHaps, Idx, StrIdx from .._utils import lengths_to_offsets, normalize_contig_name -from ._indexing import DatasetIndexer +from ._indexing import DatasetIndexer, is_str_arr from ._rag_variants import RaggedVariants from ._reconstruct import Haps, HapsTracks, Ref, RefTracks, Tracks, TrackType from ._reference import Reference @@ -91,6 +91,9 @@ def open( rng: int | np.random.Generator | None = False, deterministic: bool = True, rc_neg: bool = True, + min_af: float | None = None, + max_af: float | None = None, + region_names: str | None = None, ) -> RaggedDataset[MaybeRSEQ, MaybeRTRK]: ... @overload @staticmethod @@ -101,6 +104,9 @@ def open( rng: int | np.random.Generator | None = False, deterministic: bool = True, rc_neg: bool = True, + min_af: float | None = None, + max_af: float | None = None, + region_names: str | None = None, ) -> RaggedDataset[RaggedSeqs, MaybeRTRK]: ... @staticmethod def open( @@ -112,6 +118,7 @@ def open( rc_neg: bool = True, min_af: float | None = None, max_af: float | None = None, + region_names: str | None = None, ) -> RaggedDataset[MaybeRSEQ, MaybeRTRK]: """Open a dataset from a path. If no reference genome is provided, the dataset cannot yield sequences. Will initialize the dataset such that it will return tracks and haplotypes (reference sequences if no genotypes) if possible. @@ -151,9 +158,16 @@ def open( # read input regions and generate index map bed = pl.read_ipc(path / "input_regions.arrow") + if region_names is not None: + _region_names = bed[region_names].to_list() + else: + _region_names = None r_idx_map = bed["r_idx_map"].to_numpy().astype(np.intp) idxer = DatasetIndexer.from_region_and_sample_idxs( - r_idx_map, np.arange(len(samples)), samples + r_idxs=r_idx_map, + s_idxs=np.arange(len(samples)), + samples=samples, + regions=_region_names, ) bed = bed.drop("r_idx_map") sorted_bed = sp.bed.sort(bed) @@ -527,7 +541,13 @@ class AnnotatedHaps: "Dataset has no reference genome to reconstruct sequences from." ) seqs = Ref(reference=ref) - return evolve(self, _recon=RefTracks(seqs=seqs, tracks=tracks)) + return evolve( + self, + _recon=RefTracks( + seqs=seqs, + tracks=tracks, # type: ignore + ), + ) case "haplotypes", Haps() as haps, _, Ref() | Haps(): return evolve(self, _recon=haps.to_kind(RaggedSeqs)) @@ -536,7 +556,13 @@ class AnnotatedHaps: | RefTracks(tracks=tracks) | HapsTracks(tracks=tracks) ): - return evolve(self, _recon=HapsTracks(haps.to_kind(RaggedSeqs), tracks)) + return evolve( + self, + _recon=HapsTracks( + haps.to_kind(RaggedSeqs), + tracks, # type: ignore + ), + ) case "annotated", Haps() as haps, _, Ref() | Haps(): return evolve(self, _recon=haps.to_kind(RaggedAnnotatedHaps)) @@ -546,7 +572,11 @@ class AnnotatedHaps: | HapsTracks(tracks=tracks) ): return evolve( - self, _recon=HapsTracks(haps.to_kind(RaggedAnnotatedHaps), tracks) + self, + _recon=HapsTracks( + haps.to_kind(RaggedAnnotatedHaps), + tracks, # type: ignore + ), ) case "variants", Haps() as haps, _, Ref() | Haps(): @@ -557,7 +587,11 @@ class AnnotatedHaps: | HapsTracks(tracks=tracks) ): return evolve( - self, _recon=HapsTracks(haps.to_kind(RaggedVariants), tracks) + self, + _recon=HapsTracks( + haps.to_kind(RaggedVariants), + tracks, # type: ignore + ), ) case k, s, t, r: @@ -613,15 +647,24 @@ def with_tracks( "Can't set dataset to return tracks because it has none to begin with." ) case t, _, tr, (Ref() as seqs) | RefTracks(seqs=seqs): - tr = tr.with_tracks(t).to_kind(_kind) + tr = tr.with_tracks(t).to_kind( + _kind, # type: ignore + ) recon = RefTracks(seqs=seqs, tracks=tr) return evolve(self, _tracks=tr, _recon=recon) case t, _, tr, (Haps() as seqs) | HapsTracks(haps=seqs): - tr = tr.with_tracks(t).to_kind(_kind) - recon = HapsTracks(haps=seqs, tracks=tr) + tr = tr.with_tracks(t).to_kind( + _kind, # type: ignore + ) + recon = HapsTracks( + haps=seqs, # type: ignore + tracks=tr, + ) return evolve(self, _tracks=tr, _recon=recon) case t, _, tr, Tracks(): - tr = tr.with_tracks(t).to_kind(_kind) + tr = tr.with_tracks(t).to_kind( + _kind, # type: ignore + ) return evolve(self, _tracks=tr, _recon=tr) case k, s, t, r: assert_never(k), assert_never(s), assert_never(t), assert_never(r) @@ -851,7 +894,7 @@ def __repr__(self) -> str: def subset_to( self, - regions: Idx | None = None, + regions: StrIdx | None = None, samples: StrIdx | None = None, ) -> Self: """Subset the dataset to specific regions and/or samples by index or a boolean mask. If regions or samples @@ -910,13 +953,10 @@ def subset_to( if regions is None and samples is None: return self - if regions is not None: - if isinstance(regions, pl.Series): - regions = regions.to_numpy() - if np.issubdtype(regions.dtype, np.bool_): - regions = np.nonzero(regions)[0] - elif not np.issubdtype(regions.dtype, np.integer): - raise ValueError("`regions` must be index-like or a boolean mask.") + if is_str_arr(regions) and self._idxer.r2i_map is None: + raise ValueError( + "Cannot subset to regions by name because no region name was set." + ) idxer = self._idxer.subset_to(regions=regions, samples=samples) @@ -981,7 +1021,7 @@ def haplotype_lengths( def n_variants( self, regions: Idx | None = None, - samples: Idx | str | Sequence[str] | None = None, + samples: StrIdx | None = None, ) -> NDArray[np.int32]: """The number of variants in the dataset for specified regions and samples. @@ -1026,7 +1066,7 @@ def n_variants( def n_intervals( self, regions: Idx | None = None, - samples: Idx | str | Sequence[str] | None = None, + samples: StrIdx | None = None, ) -> NDArray[np.int32]: """The number of intervals in the dataset for specified regions and samples. @@ -1315,7 +1355,7 @@ def to_dataloader( ) def __getitem__( - self, idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]] + self, idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx] ) -> ( Ragged[np.bytes_ | np.float32] | RaggedAnnotatedHaps @@ -1333,6 +1373,11 @@ def __getitem__( ..., ] ): + if is_str_arr(idx) and self._idxer.r2i_map is None: + raise ValueError( + "Cannot query regions by name because no region name was set." + ) + # (b) ds_idx, squeeze, out_reshape = self._idxer.parse_idx(idx) r_idx, _ = np.unravel_index(ds_idx, self.full_shape) @@ -1579,17 +1624,17 @@ def with_tracks( @overload def __getitem__( self: ArrayDataset[SEQ, None], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> SEQ: ... @overload def __getitem__( self: ArrayDataset[None, NDArray[np.float32]], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> NDArray[np.float32]: ... @overload def __getitem__( self: ArrayDataset[SEQ, NDArray[np.float32]], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> tuple[SEQ, NDArray[np.float32]]: ... @overload def __getitem__( @@ -1614,15 +1659,15 @@ def __getitem__( @overload def __getitem__( self: ArrayDataset[MaybeSEQ, RaggedIntervals], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> RaggedIntervals | tuple[SEQ, RaggedIntervals]: ... @overload def __getitem__( self: ArrayDataset[MaybeSEQ, MaybeTRK], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> SEQ | NDArray[np.float32] | tuple[SEQ, NDArray[np.float32]]: ... def __getitem__( - self, idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]] + self, idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx] ) -> SEQ | TRK | tuple[SEQ, TRK]: return super().__getitem__(idx) # type: ignore @@ -1729,54 +1774,54 @@ def with_tracks( @overload def __getitem__( self: RaggedDataset[None, None], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> NoReturn: ... @overload def __getitem__( self: RaggedDataset[RSEQ, None], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> RSEQ: ... @overload def __getitem__( self: RaggedDataset[None, Ragged[np.float32]], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> Ragged[np.float32]: ... @overload def __getitem__( self: RaggedDataset[RSEQ, Ragged[np.float32]], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> tuple[RSEQ, Ragged[np.float32]]: ... @overload def __getitem__( self: RaggedDataset[None, RaggedIntervals], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> RaggedIntervals: ... @overload def __getitem__( self: RaggedDataset[RSEQ, RaggedIntervals], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> tuple[RSEQ, RaggedIntervals]: ... @overload def __getitem__( self: RaggedDataset[RSEQ, MaybeRTRK], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> RSEQ | tuple[RSEQ, Ragged[np.float32]]: ... @overload def __getitem__( self: RaggedDataset[MaybeRSEQ, Ragged[np.float32]], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> Ragged[np.float32] | tuple[RSEQ, Ragged[np.float32]]: ... @overload def __getitem__( self: RaggedDataset[MaybeRSEQ, RaggedIntervals], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> RaggedIntervals | tuple[RSEQ, RaggedIntervals]: ... @overload def __getitem__( self: RaggedDataset[MaybeRSEQ, MaybeRTRK], - idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]], + idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx], ) -> RSEQ | Ragged[np.float32] | tuple[RSEQ, Ragged[np.float32]]: ... def __getitem__( - self, idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]] + self, idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx] ) -> RSEQ | RTRK | tuple[RSEQ, RTRK]: return super().__getitem__(idx) # type: ignore diff --git a/python/genvarloader/_dataset/_indexing.py b/python/genvarloader/_dataset/_indexing.py index 14b12b0..02c3a21 100644 --- a/python/genvarloader/_dataset/_indexing.py +++ b/python/genvarloader/_dataset/_indexing.py @@ -1,11 +1,11 @@ from collections.abc import Sequence -from typing import Literal, cast +from typing import Any, Literal, cast import numpy as np from attrs import define, evolve from hirola import HashTable from numpy.typing import NDArray -from typing_extensions import assert_never +from typing_extensions import TypeGuard, assert_never from .._types import Idx, StrIdx from .._utils import idx_like_to_array, is_dtype @@ -19,6 +19,8 @@ class DatasetIndexer: """Full map from input sample indices to on-disk sample indices.""" s2i_map: HashTable """Map from input sample names to on-disk sample indices.""" + r2i_map: HashTable | None = None + """Map from input region names to on-disk region indices.""" region_subset_idxs: NDArray[np.integer] | None = None """Which input regions are included in the subset.""" sample_subset_idxs: NDArray[np.integer] | None = None @@ -30,16 +32,29 @@ def from_region_and_sample_idxs( r_idxs: NDArray[np.integer], s_idxs: NDArray[np.integer], samples: list[str], + regions: list[str] | None = None, ): + if regions is not None: + _regions = np.array(regions) + r2i_map = HashTable( + max=len(_regions) * 2, # type: ignore | 2x size for perf > mem + dtype=_regions.dtype, + ) + r2i_map.add(_regions) + else: + r2i_map = None + _samples = np.array(samples) s2i_map = HashTable( max=len(_samples) * 2, # type: ignore | 2x size for perf > mem dtype=_samples.dtype, ) s2i_map.add(_samples) + return cls( full_region_idxs=r_idxs, full_sample_idxs=s_idxs, + r2i_map=r2i_map, s2i_map=s2i_map, ) @@ -84,7 +99,7 @@ def __len__(self): def subset_to( self, - regions: Idx | None = None, + regions: StrIdx | None = None, samples: StrIdx | None = None, ) -> "DatasetIndexer": """Subset the dataset to specific regions and/or samples.""" @@ -92,12 +107,13 @@ def subset_to( return self if samples is not None: - samples = self.s2i(samples) + samples = self.sample2idx(samples) sample_idxs = idx_like_to_array(samples, self.n_samples) else: sample_idxs = np.arange(self.n_samples, dtype=np.intp) if regions is not None: + regions = self.region2idx(regions) region_idxs = idx_like_to_array(regions, self.n_regions) else: region_idxs = np.arange(self.n_regions, dtype=np.intp) @@ -111,7 +127,7 @@ def to_full_dataset(self) -> "DatasetIndexer": return evolve(self, region_subset_idxs=None, sample_subset_idxs=None) def parse_idx( - self, idx: Idx | tuple[Idx] | tuple[Idx, StrIdx] + self, idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx] ) -> tuple[NDArray[np.integer], bool, tuple[int, ...] | None]: out_reshape = None squeeze = False @@ -125,23 +141,24 @@ def parse_idx( else: regions, samples = idx - s_idx = self.s2i(samples) - idx = (regions, s_idx) + r_idx = self.region2idx(regions) + s_idx = self.sample2idx(samples) + idx = (r_idx, s_idx) idx_t = idx_type(idx) if idx_t == "basic": if all(isinstance(i, (int, np.integer)) for i in idx): squeeze = True - r_idx = np.atleast_1d(self._r_idx[regions]) + r_idx = np.atleast_1d(self._r_idx[r_idx]) s_idx = np.atleast_1d(self._s_idx[s_idx]) idx = np.ravel_multi_index(np.ix_(r_idx, s_idx), self.full_shape).squeeze() if isinstance(regions, slice) and isinstance(samples, slice): out_reshape = (len(r_idx), len(s_idx)) elif idx_t == "adv": - r_idx = self._r_idx[regions] + r_idx = self._r_idx[r_idx] s_idx = self._s_idx[s_idx] idx = np.ravel_multi_index((r_idx, s_idx), self.full_shape) elif idx_t == "combo": - r_idx = self._r_idx[regions] + r_idx = self._r_idx[r_idx] s_idx = self._s_idx[s_idx] idx = np.ravel_multi_index( np.ix_(r_idx.ravel(), s_idx.ravel()), self.full_shape @@ -176,17 +193,26 @@ def _s_idx(self): return self.full_sample_idxs return self.full_sample_idxs[self.sample_subset_idxs] - def s2i(self, samples: StrIdx) -> Idx: + def sample2idx(self, samples: StrIdx) -> Idx: """Convert sample names to sample indices.""" return s2i(samples, self.s2i_map) + def region2idx(self, regions: StrIdx) -> Idx: + """Convert region names to region indices.""" + return s2i(regions, self.r2i_map) -def s2i(str_idx: StrIdx, map: HashTable) -> Idx: + +def s2i(str_idx: StrIdx, map: HashTable | None) -> Idx: """Convert a string index to an integer index using a hirola.HashTable.""" if not isinstance(str_idx, (np.ndarray, slice)): str_idx = np.asarray(str_idx) - if is_dtype(str_idx, np.str_) or is_dtype(str_idx, np.object_): + if is_str_arr(str_idx): + if map is None: + raise ValueError( + "Queries are names/strings, but no string-to-integer mapping is available." + ) + idx = map.get(str_idx) if (np.atleast_1d(idx) == -1).any(): raise KeyError( @@ -209,7 +235,7 @@ def idx_type( """Check if the index is a fancy index.""" if not isinstance(idx, tuple): idx = (idx,) - n_adv = sum(map(is_adv_idx, idx)) + n_adv = sum(map(lambda idx: isinstance(idx, (Sequence, np.ndarray)), idx)) if n_adv == 0: return "basic" elif n_adv == 1: @@ -220,6 +246,6 @@ def idx_type( raise ValueError(f"Invalid index type: {idx}") -def is_adv_idx(idx: Idx) -> bool: - """Check if the index is a fancy index.""" - return isinstance(idx, (Sequence, np.ndarray)) +def is_str_arr(obj: Any) -> TypeGuard[NDArray[np.str_] | NDArray[np.object_]]: + """Check if the object is a string array.""" + return is_dtype(obj, np.str_) or is_dtype(obj, np.object_) diff --git a/python/genvarloader/_dataset/_rag_variants.py b/python/genvarloader/_dataset/_rag_variants.py index f838719..5b8c232 100644 --- a/python/genvarloader/_dataset/_rag_variants.py +++ b/python/genvarloader/_dataset/_rag_variants.py @@ -77,6 +77,25 @@ def from_ak(cls, arr: ak.Array) -> RaggedVariants: if {"ref", "ilen"}.isdisjoint(fields): raise ValueError("Must have one of ref or ilen.") + def find_and_convert_to_ragged(content: Content, depth_context: dict, **kwargs): + if isinstance(content, (ListArray, ListOffsetArray)): + depth_context["n_varlen"] += 1 + + if ( + # is a varlen leaf + isinstance(content, (ListArray, ListOffsetArray)) + and isinstance(content.content, NumpyArray) + # is the only varlen leaf in this branch + and depth_context["n_varlen"] == 1 + # has no parameters that might conflict with Ragged + and len(content.parameters) == 0 + ): + return ak.with_parameter(content, "__list__", "Ragged", highlevel=False) + + arr = ak.transform( # type: ignore + find_and_convert_to_ragged, arr, depth_context={"n_varlen": 0} + ) + return ak.with_parameter(arr, "__record__", RaggedVariants.__name__) @property @@ -189,14 +208,22 @@ def rc_(self, to_rc: NDArray[np.bool_] | None = None) -> Self: The RaggedVariants object with the alleles reverse complemented. """ if to_rc is None: - to_rc = np.ones(self.shape[0], np.bool_) + to_rc = np.ones(self.shape[0], np.bool_) # type: ignore elif not to_rc.any(): return self - self["alt"] = ak.where(to_rc, reverse_complement(self["alt"]), self["alt"]) + self["alt"] = ak.where( + to_rc, + reverse_complement(self["alt"]), # type: ignore + self["alt"], + ) if "ref" in self.fields: - self["ref"] = ak.where(to_rc, reverse_complement(self["ref"]), self["ref"]) + self["ref"] = ak.where( + to_rc, + reverse_complement(self["ref"]), # type: ignore + self["ref"], + ) return self @@ -340,7 +367,11 @@ def _alleles_to_nested_tensor( offsets = offsets.offsets.data.astype(np.int32) # type: ignore lengths = np.diff(offsets) - max_alen = lengths.max().item() + if len(lengths) == 0: + max_alen = 0 + else: + max_alen = lengths.max().item() + offsets = torch.from_numpy(offsets) return _alleles, offsets, max_alen diff --git a/python/genvarloader/_ragged.py b/python/genvarloader/_ragged.py index 366944c..052537e 100644 --- a/python/genvarloader/_ragged.py +++ b/python/genvarloader/_ragged.py @@ -75,9 +75,9 @@ def squeeze(self, axis: int | tuple[int, ...] | None = None) -> RaggedIntervals: Axis or axes to squeeze. If None, all axes of length 1 are squeezed. """ return RaggedIntervals( - self.starts.squeeze(axis), - self.ends.squeeze(axis), - self.values.squeeze(axis), + self.starts.squeeze(axis), # type: ignore + self.ends.squeeze(axis), # type: ignore + self.values.squeeze(axis), # type: ignore ) def to_fixed_shape( @@ -286,7 +286,7 @@ def ufunc_comp_dna(seq: NDArray[np.uint8]) -> NDArray[np.uint8]: def _ak_comp_dna_helper(layout, **kwargs): if layout.is_numpy: return NumpyArray( - ufunc_comp_dna(layout.data), # type: ignoreF + ufunc_comp_dna(layout.data), # type: ignore parameters=layout.parameters, )