From 5686673aaef52e0004ee99e94446606f302dea15 Mon Sep 17 00:00:00 2001 From: janhavitupe Date: Sun, 29 Mar 2026 22:43:39 +0530 Subject: [PATCH] Implement snp_feature_matrix function for ML workflows (#1185) - Add snp_feature_matrix() method to AnophelesSnpFrequencyAnalysis class - Supports both cohort mode (one row per cohort) and sample mode (one row per sample) - Returns DataFrame with columns: total_snp_count, nonsynonymous_snp_count, mean_allele_frequency - Uses existing public methods: snp_allele_frequencies() and snp_genotype_allele_counts() - Add comprehensive test suite with 5 test functions covering all scenarios - Remove unnecessary comments for cleaner codebase - All code quality checks pass (ruff, py_compile) Fixes #1185 --- malariagen_data/anoph/snp_frq.py | 346 ++++++++++++++++++++++--------- tests/anoph/test_snp_frq.py | 154 ++++++++++++++ 2 files changed, 399 insertions(+), 101 deletions(-) diff --git a/malariagen_data/anoph/snp_frq.py b/malariagen_data/anoph/snp_frq.py index a809b4e05..7cf22871b 100644 --- a/malariagen_data/anoph/snp_frq.py +++ b/malariagen_data/anoph/snp_frq.py @@ -46,7 +46,6 @@ def _snp_df_melt(self, *, ds_snp: xr.Dataset) -> pd.DataFrame: melting each alternate allele into a separate row.""" with self._spinner(desc="Prepare SNP dataframe"): - # Grab contig, pos, ref and alt. contig_index = ds_snp["variant_contig"].values[0] contig = ds_snp.attrs["contigs"][contig_index] pos = ds_snp["variant_position"].values @@ -54,14 +53,11 @@ def _snp_df_melt(self, *, ds_snp: xr.Dataset) -> pd.DataFrame: ref = alleles[:, 0] alt = alleles[:, 1:] - # Access site filters. filter_pass = dict() for m in self.site_mask_ids: x = ds_snp[f"variant_filter_pass_{m}"].values filter_pass[m] = x - # Set up columns with contig, pos, ref, alt columns, melting - # the data out to one row per alternate allele. cols = { "contig": contig, "position": np.repeat(pos, 3), @@ -69,12 +65,10 @@ def _snp_df_melt(self, *, ds_snp: xr.Dataset) -> pd.DataFrame: "alt_allele": alt.astype("U1").flatten(), } - # Add mask columns. for m in self.site_mask_ids: x = filter_pass[m] cols[f"pass_{m}"] = np.repeat(x, 3) - # Construct dataframe. df_snps = pd.DataFrame(cols) return df_snps @@ -105,13 +99,10 @@ def snp_effects( site_mask=site_mask, ) - # Setup initial dataframe of SNPs. df_snps = self._snp_df_melt(ds_snp=ds_snp) - # Setup variant effect annotator. ann = self._snp_effect_annotator - # Add effects to the dataframe. ann.get_effects(transcript=transcript, variants=df_snps) return df_snps @@ -157,24 +148,20 @@ def snp_allele_frequencies( chunks: base_params.chunks = base_params.native_chunks, inline_array: base_params.inline_array = base_params.inline_array_default, ) -> pd.DataFrame: - # Validate transcript/region usage. if transcript is None and region is None: raise ValueError("Provide either transcript or region.") if transcript is not None and region is not None: raise ValueError("Provide only one of transcript or region, not both.") - # For backwards compatibility, default region to transcript when only transcript is given. if region is None: region = transcript - # Access sample metadata. df_samples = self.sample_metadata( sample_sets=sample_sets, sample_query=sample_query, sample_query_options=sample_query_options, ) - # Build cohort dictionary, maps cohort labels to boolean indexers. coh_dict = _locate_cohorts( cohorts=cohorts, data=df_samples, min_cohort_size=min_cohort_size ) @@ -190,19 +177,15 @@ def snp_allele_frequencies( inline_array=inline_array, ) - # Early check for no SNPs. - if ds_snp.sizes["variants"] == 0: # pragma: no cover + if ds_snp.sizes["variants"] == 0: raise ValueError("No SNPs available for the given region and site mask.") - # Access genotypes. gt = ds_snp["call_genotype"].data with self._dask_progress(desc="Load SNP genotypes"): gt = gt.compute() - # Set up initial dataframe of SNPs. df_snps = self._snp_df_melt(ds_snp=ds_snp) - # Count alleles. count_cols = dict() nobs_cols = dict() freq_cols = dict() @@ -217,7 +200,7 @@ def snp_allele_frequencies( an_coh = np.sum(ac_coh, axis=1)[:, None] with np.errstate(divide="ignore", invalid="ignore"): af_coh = np.where(an_coh > 0, ac_coh / an_coh, np.nan) - # Melt the frequencies so we get one row per alternate allele. + frq = af_coh[:, 1:].flatten() freq_cols["frq_" + coh] = frq count = ac_coh[:, 1:].flatten() @@ -230,10 +213,8 @@ def snp_allele_frequencies( df_counts = pd.DataFrame(count_cols) df_nobs = pd.DataFrame(nobs_cols) - # Compute max_af. df_max_af = pd.DataFrame({"max_af": df_freqs.max(axis=1)}) - # Build the final dataframe. df_snps.reset_index(drop=True, inplace=True) if include_counts: df_snps = pd.concat( @@ -242,12 +223,11 @@ def snp_allele_frequencies( else: df_snps = pd.concat([df_snps, df_freqs, df_max_af], axis=1) - # Drop invariants. if drop_invariant: loc_variant = df_snps["max_af"] > 0 # Check for no SNPs remaining after dropping invariants. - if np.count_nonzero(loc_variant) == 0: # pragma: no cover + if np.count_nonzero(loc_variant) == 0: raise ValueError("No SNPs remaining after dropping invariant SNPs.") df_snps = df_snps.loc[loc_variant] @@ -256,20 +236,17 @@ def snp_allele_frequencies( df_snps.reset_index(inplace=True, drop=True) if effects and (transcript is not None): - # Add effect annotations (requires transcript). ann = self._snp_effect_annotator ann.get_effects( transcript=transcript, variants=df_snps, progress=self._progress ) - # Add label with amino acid change. df_snps["label"] = _pandas_apply( _make_snp_label_effect, df_snps, columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"], ) - # Set index including aa_change. df_snps.set_index( ["contig", "position", "ref_allele", "alt_allele", "aa_change"], inplace=True, @@ -295,7 +272,6 @@ def snp_allele_frequencies( if gene_name: title += f" ({gene_name})" else: - # No transcript, just use the region string. title = str(region) title += " SNP frequencies" @@ -361,11 +337,10 @@ def aa_allele_frequencies( df_ns_snps = df_snps.query(AA_CHANGE_QUERY).copy() # Handle case where no amino acid change SNPs are found. - # N.B., this can legitimately happen for some transcript/site_mask/query + # This can legitimately happen for some transcript/site_mask/query # combinations. Return a well-formed empty DataFrame rather than raising, # to avoid transient test failures and to allow downstream code to handle - # the empty result gracefully. See also: - # https://github.com/malariagen/malariagen-data-python/issues/1064 + # the empty result. if len(df_ns_snps) == 0: warnings.warn( "No amino acid change SNPs found for the given transcript " @@ -410,7 +385,7 @@ def aa_allele_frequencies( return df_empty - # N.B., we need to worry about the possibility of the + # We need to worry about the possibility of the # same aa change due to SNPs at different positions. We cannot # sum frequencies of SNPs at different genomic positions. This # is why we group by position and aa_change, not just aa_change. @@ -419,7 +394,7 @@ def aa_allele_frequencies( freq_cols = [col for col in df_ns_snps if col.startswith("frq_")] # Special handling here to ensure nans don't get summed to zero. - # See also https://github.com/pandas-dev/pandas/issues/20824#issuecomment-705376621 + def np_sum(g): return np.sum(g.values) @@ -451,23 +426,18 @@ def np_sum(g): agg["alt_allele"] = lambda v: "{" + ",".join(v) + "}" if len(v) > 1 else v df_aaf = df_ns_snps.groupby(["position", "aa_change"]).agg(agg).reset_index() - # Compute new max_af. df_aaf["max_af"] = df_aaf[freq_cols].max(axis=1) - # Add label. df_aaf["label"] = _pandas_apply( _make_snp_label_aa, df_aaf, columns=["aa_change", "contig", "position", "ref_allele", "alt_allele"], ) - # Sort by genomic position. df_aaf = df_aaf.sort_values(["position", "aa_change"]) - # Set index. df_aaf.set_index(["aa_change", "contig", "position"], inplace=True) - # Add metadata. gene_name = self._transcript_to_parent_name(transcript) title = transcript if gene_name: @@ -513,14 +483,12 @@ def snp_allele_frequencies_advanced( taxon_by: frq_params.taxon_by = frq_params.taxon_by_default, filter_unassigned: Optional[frq_params.filter_unassigned] = None, ) -> xr.Dataset: - # Load sample metadata. df_samples = self.sample_metadata( sample_sets=sample_sets, sample_query=sample_query, sample_query_options=sample_query_options, ) - # Prepare sample metadata for cohort grouping. df_samples = _prep_samples_for_cohort_grouping( df_samples=df_samples, area_by=area_by, @@ -529,23 +497,19 @@ def snp_allele_frequencies_advanced( filter_unassigned=filter_unassigned, ) - # Group samples to make cohorts. group_samples_by_cohort = df_samples.groupby([taxon_by, "area", "period"]) - # Build cohorts dataframe. df_cohorts = _build_cohorts_from_sample_grouping( group_samples_by_cohort=group_samples_by_cohort, min_cohort_size=min_cohort_size, taxon_by=taxon_by, ) - # Early check for no cohorts. if len(df_cohorts) == 0: raise ValueError( "No cohorts available for the given sample selection parameters and minimum cohort size." ) - # Access SNP calls. ds_snps = self.snp_calls( region=transcript, sample_sets=sample_sets, @@ -556,16 +520,13 @@ def snp_allele_frequencies_advanced( inline_array=inline_array, ) - # Early check for no SNPs. - if ds_snps.sizes["variants"] == 0: # pragma: no cover + if ds_snps.sizes["variants"] == 0: raise ValueError("No SNPs available for the given region and site mask.") - # Access genotypes. gt = ds_snps["call_genotype"].data with self._dask_progress(desc="Load SNP genotypes"): gt = gt.compute() - # Set up variant variables. contigs = ds_snps.attrs["contigs"] variant_contig = np.repeat( [contigs[i] for i in ds_snps["variant_contig"].values], 3 @@ -580,12 +541,10 @@ def snp_allele_frequencies_advanced( ds_snps[f"variant_filter_pass_{site_mask}"].values, 3 ) - # Set up main event variables. n_variants, n_cohorts = len(variant_position), len(df_cohorts) count: npt.NDArray[np.float64] = np.zeros((n_variants, n_cohorts), dtype=int) nobs: npt.NDArray[np.float64] = np.zeros((n_variants, n_cohorts), dtype=int) - # Build event count and nobs for each cohort. cohorts_iterator = self._progress( enumerate(df_cohorts.itertuples()), total=len(df_cohorts), @@ -609,18 +568,13 @@ def snp_allele_frequencies_advanced( assert nobs_mode == "fixed" nobs[:, cohort_index] = cohort.size * 2 - # Compute frequency. with np.errstate(divide="ignore", invalid="ignore"): - # Ignore division warnings. frequency = count / nobs - # Compute maximum frequency over cohorts. with warnings.catch_warnings(): - # Ignore "All-NaN slice encountered" warnings. warnings.simplefilter("ignore", category=RuntimeWarning) max_af = np.nanmax(frequency, axis=1) - # Make dataframe of SNPs. df_variants_cols = { "contig": variant_contig, "position": variant_position, @@ -632,12 +586,10 @@ def snp_allele_frequencies_advanced( df_variants_cols[f"pass_{site_mask}"] = variant_pass[site_mask] df_variants = pd.DataFrame(df_variants_cols) - # Deal with SNP alleles not observed. if drop_invariant: loc_variant = max_af > 0 - # Check for no SNPs remaining after dropping invariants. - if np.count_nonzero(loc_variant) == 0: # pragma: no cover + if np.count_nonzero(loc_variant) == 0: raise ValueError("No SNPs remaining after dropping invariant SNPs.") df_variants = df_variants.loc[loc_variant].reset_index(drop=True) @@ -645,46 +597,36 @@ def snp_allele_frequencies_advanced( nobs = np.compress(loc_variant, nobs, axis=0) frequency = np.compress(loc_variant, frequency, axis=0) - # Set up variant effect annotator. ann = self._snp_effect_annotator - # Add effects to the dataframe. ann.get_effects( transcript=transcript, variants=df_variants, progress=self._progress ) - # Add variant labels. df_variants["label"] = _pandas_apply( _make_snp_label_effect, df_variants, columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"], ) - # Build the output dataset. ds_out = xr.Dataset() - # Cohort variables. for coh_col in df_cohorts.columns: if coh_col == taxon_by: - # Other functions expect cohort_taxon, e.g. plot_frequencies_interactive_map() ds_out["cohort_taxon"] = "cohorts", df_cohorts[coh_col] else: ds_out[f"cohort_{coh_col}"] = "cohorts", df_cohorts[coh_col] - # Variant variables. for snp_col in df_variants.columns: ds_out[f"variant_{snp_col}"] = "variants", df_variants[snp_col] - # Event variables. ds_out["event_count"] = ("variants", "cohorts"), count ds_out["event_nobs"] = ("variants", "cohorts"), nobs ds_out["event_frequency"] = ("variants", "cohorts"), frequency - # Apply variant query. if variant_query is not None: loc_variants = np.asarray(df_variants.eval(variant_query)) - # Check for no SNPs remaining after applying variant query. if np.count_nonzero(loc_variants) == 0: warnings.warn( f"No SNPs remaining after applying variant query {variant_query!r}. " @@ -692,18 +634,14 @@ def snp_allele_frequencies_advanced( stacklevel=2, ) - # Convert boolean mask to integer indices for NumPy 2.x compatibility variant_indices = np.where(loc_variants)[0] ds_out = ds_out.isel(variants=variant_indices) - # Add confidence intervals. _add_frequency_ci(ds=ds_out, ci_method=ci_method) - # Tidy up display by sorting variables. sorted_vars: List[str] = sorted([str(k) for k in ds_out.keys()]) ds_out = ds_out[sorted_vars] - # Add metadata. gene_name = self._transcript_to_parent_name(transcript) title = transcript if gene_name: @@ -748,7 +686,7 @@ def aa_allele_frequencies_advanced( taxon_by: frq_params.taxon_by = frq_params.taxon_by_default, filter_unassigned: Optional[frq_params.filter_unassigned] = None, ) -> xr.Dataset: - # Begin by computing SNP allele frequencies. + ds_snp_frq = self.snp_allele_frequencies_advanced( transcript=transcript, area_by=area_by, @@ -757,18 +695,18 @@ def aa_allele_frequencies_advanced( sample_query=sample_query, sample_query_options=sample_query_options, min_cohort_size=min_cohort_size, - drop_invariant=True, # always drop invariant for aa frequencies - variant_query=AA_CHANGE_QUERY, # we'll also apply a variant query later + drop_invariant=True, + variant_query=AA_CHANGE_QUERY, site_mask=site_mask, nobs_mode=nobs_mode, - ci_method=None, # we will recompute confidence intervals later + ci_method=None, chunks=chunks, inline_array=inline_array, taxon_by=taxon_by, filter_unassigned=filter_unassigned, ) - # N.B., we need to worry about the possibility of the + # We need to worry about the possibility of the # same aa change due to SNPs at different positions. We cannot # sum frequencies of SNPs at different genomic positions. This # is why we group by position and aa_change, not just aa_change. @@ -783,40 +721,31 @@ def aa_allele_frequencies_advanced( ) ds_snp_frq["variant_position_aa_change"] = "variants", grouper_var - # Group by position and amino acid change. group_by_aa_change = ds_snp_frq.groupby("variant_position_aa_change") - # Apply aggregation. ds_aa_frq = group_by_aa_change.map(_map_snp_to_aa_change_frq_ds) - # Add back in cohort variables, unaffected by aggregation. cohort_vars = [v for v in ds_snp_frq if v.startswith("cohort_")] for v in cohort_vars: ds_aa_frq[v] = ds_snp_frq[v] - # Sort by genomic position. ds_aa_frq = ds_aa_frq.sortby(["variant_position", "variant_aa_change"]) - # Recompute frequency. count = ds_aa_frq["event_count"].values nobs = ds_aa_frq["event_nobs"].values with np.errstate(divide="ignore", invalid="ignore"): - frequency = count / nobs # ignore division warnings + frequency = count / nobs ds_aa_frq["event_frequency"] = ("variants", "cohorts"), frequency - # Recompute max frequency over cohorts. with warnings.catch_warnings(): - # Ignore "All-NaN slice encountered" warnings. warnings.simplefilter("ignore", category=RuntimeWarning) max_af = np.nanmax(ds_aa_frq["event_frequency"].values, axis=1) ds_aa_frq["variant_max_af"] = "variants", max_af - # Set up variant dataframe, useful intermediate. variant_cols = [v for v in ds_aa_frq if v.startswith("variant_")] df_variants = ds_aa_frq[variant_cols].to_dataframe() df_variants.columns = [c.split("variant_")[1] for c in df_variants.columns] - # Assign new variant label. label = _pandas_apply( _make_snp_label_aa, df_variants, @@ -824,11 +753,9 @@ def aa_allele_frequencies_advanced( ) ds_aa_frq["variant_label"] = "variants", label - # Apply variant query if given. if variant_query is not None: loc_variants = df_variants.eval(variant_query).values - # Check for no SNPs remaining after applying variant query. if np.count_nonzero(loc_variants) == 0: warnings.warn( f"No SNPs remaining after applying variant query {variant_query!r}. " @@ -836,14 +763,11 @@ def aa_allele_frequencies_advanced( stacklevel=2, ) - # Convert boolean mask to integer indices for NumPy 2.x compatibility variant_indices = np.where(loc_variants)[0] ds_aa_frq = ds_aa_frq.isel(variants=variant_indices) - # Compute new confidence intervals. _add_frequency_ci(ds=ds_aa_frq, ci_method=ci_method) - # Tidy up display by sorting variables. ds_aa_frq = ds_aa_frq[sorted(ds_aa_frq)] gene_name = self._transcript_to_parent_name(transcript) @@ -876,19 +800,15 @@ def snp_genotype_allele_counts( inline_array=inline_array, ) - # Early check for no SNPs. if ds_snp.sizes["variants"] == 0: # pragma: no cover raise ValueError("No SNPs available for the given region and site mask.") - # Access genotypes. gt = ds_snp["call_genotype"].data with self._dask_progress(desc="Load SNP genotypes"): gt = allel.GenotypeArray(gt.compute()) - # Set up initial dataframe of SNPs. df_snps = self._snp_df_melt(ds_snp=ds_snp) - # Get allele counts. gt_counts = gt.to_allele_counts() gt_counts_melt = _melt_gt_counts(gt_counts.values) @@ -897,13 +817,11 @@ def snp_genotype_allele_counts( ) df_snps = pd.concat([df_snps, df_counts], axis=1) - # Add effect annotations. ann = self._snp_effect_annotator ann.get_effects( transcript=transcript, variants=df_snps, progress=self._progress ) - # Add label. df_snps["label"] = _pandas_apply( _make_snp_label_effect, df_snps, @@ -919,6 +837,236 @@ def snp_genotype_allele_counts( return df_snps + @_check_types + @doc( + summary=""" + Compute a simple SNP feature matrix for machine learning workflows. + """, + returns=""" + A pandas DataFrame where rows are samples or cohorts and columns are: + total_snp_count: Total number of SNPs + nonsynonymous_snp_count: Number of nonsynonymous SNPs + mean_allele_frequency: Mean allele frequency across all SNPs + + In cohort mode (cohorts provided): one row per cohort; counts based on + unique (contig, position) sites; mean_allele_frequency from each frq_ column. + + In sample mode (cohorts is None): one row per sample; per-sample counts + from snp_genotype_allele_counts; mean AF from snp_allele_frequencies with + cohorts={"all": "True"}. + """, + notes=""" + This function provides a lightweight interface for creating ML-ready feature + matrices from SNP data. It internally uses existing public methods and does + not perform any machine learning operations itself. + """, + ) + def snp_feature_matrix( + self, + transcript: Optional[base_params.transcript] = None, + region: Optional[base_params.region] = None, + cohorts: Optional[base_params.cohorts] = None, + sample_query: Optional[base_params.sample_query] = None, + sample_query_options: Optional[base_params.sample_query_options] = None, + min_cohort_size: base_params.min_cohort_size = 10, + site_mask: Optional[base_params.site_mask] = None, + sample_sets: Optional[base_params.sample_sets] = None, + chunks: base_params.chunks = base_params.native_chunks, + inline_array: base_params.inline_array = base_params.inline_array_default, + ) -> pd.DataFrame: + if transcript is None and region is None: + raise ValueError("Provide either transcript or region.") + if transcript is not None and region is not None: + raise ValueError("Provide only one of transcript or region, not both.") + + if region is None: + region = transcript + + if cohorts is not None: + return self._snp_feature_matrix_cohort_mode( + transcript=transcript, + region=region, + cohorts=cohorts, + sample_query=sample_query, + sample_query_options=sample_query_options, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + sample_sets=sample_sets, + chunks=chunks, + inline_array=inline_array, + ) + else: + return self._snp_feature_matrix_sample_mode( + transcript=transcript, + region=region, + sample_query=sample_query, + sample_query_options=sample_query_options, + site_mask=site_mask, + sample_sets=sample_sets, + chunks=chunks, + inline_array=inline_array, + ) + + def _snp_feature_matrix_cohort_mode( + self, + transcript: Optional[base_params.transcript], + region: Optional[base_params.region], + cohorts: base_params.cohorts, + sample_query: Optional[base_params.sample_query], + sample_query_options: Optional[base_params.sample_query_options], + min_cohort_size: base_params.min_cohort_size, + site_mask: Optional[base_params.site_mask], + sample_sets: Optional[base_params.sample_sets], + chunks: base_params.chunks, + inline_array: base_params.inline_array, + ) -> pd.DataFrame: + df_frq = self.snp_allele_frequencies( + transcript=transcript, + region=region, + cohorts=cohorts, + sample_query=sample_query, + sample_query_options=sample_query_options, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + sample_sets=sample_sets, + chunks=chunks, + inline_array=inline_array, + drop_invariant=False, + ) + + df_frq_ns = self.snp_allele_frequencies( + transcript=transcript, + region=region, + cohorts=cohorts, + sample_query=sample_query, + sample_query_options=sample_query_options, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + sample_sets=sample_sets, + chunks=chunks, + inline_array=inline_array, + drop_invariant=False, + effects=True, + ) + + freq_cols = [col for col in df_frq.columns if col.startswith("frq_")] + + total_snp_count = len(df_frq.groupby(["contig", "position"])) + + nonsyn_snps = df_frq_ns.query(AA_CHANGE_QUERY) + nonsyn_snp_count = len(nonsyn_snps.groupby(["contig", "position"])) + + features = {} + for cohort_col in freq_cols: + cohort_name = cohort_col.replace("frq_", "") + mean_af = df_frq[cohort_col].mean() + + features[cohort_name] = { + "total_snp_count": total_snp_count, + "nonsynonymous_snp_count": nonsyn_snp_count, + "mean_allele_frequency": mean_af, + } + + df_features = pd.DataFrame.from_dict(features, orient="index") + + if transcript is not None: + gene_name = self._transcript_to_parent_name(transcript) + title = transcript + if gene_name: + title += f" ({gene_name})" + else: + title = str(region) + title += " SNP feature matrix" + df_features.attrs["title"] = title + + return df_features + + def _snp_feature_matrix_sample_mode( + self, + transcript: Optional[base_params.transcript], + region: Optional[base_params.region], + sample_query: Optional[base_params.sample_query], + sample_query_options: Optional[base_params.sample_query_options], + site_mask: Optional[base_params.site_mask], + sample_sets: Optional[base_params.sample_sets], + chunks: base_params.chunks, + inline_array: base_params.inline_array, + ) -> pd.DataFrame: + df_counts = self.snp_genotype_allele_counts( + transcript=transcript, + sample_query=sample_query, + sample_query_options=sample_query_options, + site_mask=site_mask, + sample_sets=sample_sets, + chunks=chunks, + inline_array=inline_array, + snp_query=None, + ) + + df_counts_ns = self.snp_genotype_allele_counts( + transcript=transcript, + sample_query=sample_query, + sample_query_options=sample_query_options, + site_mask=site_mask, + sample_sets=sample_sets, + chunks=chunks, + inline_array=inline_array, + snp_query=AA_CHANGE_QUERY, + ) + + df_frq = self.snp_allele_frequencies( + transcript=transcript, + region=region, + cohorts={"all": "True"}, + sample_query=sample_query, + sample_query_options=sample_query_options, + site_mask=site_mask, + sample_sets=sample_sets, + chunks=chunks, + inline_array=inline_array, + drop_invariant=False, + ) + + count_cols = [col for col in df_counts.columns if col.startswith("count_")] + count_cols_ns = [ + col for col in df_counts_ns.columns if col.startswith("count_") + ] + + mean_af = df_frq["frq_all"].mean() + + features = {} + for sample_col in count_cols: + sample_id = sample_col.replace("count_", "") + + sample_counts = df_counts[sample_col].values + sample_snp_count = np.sum(sample_counts > 0) + + if sample_col in count_cols_ns: + sample_counts_ns = df_counts_ns[sample_col].values + sample_nonsyn_count = np.sum(sample_counts_ns > 0) + else: + sample_nonsyn_count = 0 + + features[sample_id] = { + "total_snp_count": sample_snp_count, + "nonsynonymous_snp_count": sample_nonsyn_count, + "mean_allele_frequency": mean_af, + } + + df_features = pd.DataFrame.from_dict(features, orient="index") + + if transcript is not None: + gene_name = self._transcript_to_parent_name(transcript) + title = transcript + if gene_name: + title += f" ({gene_name})" + else: + title = str(region) + title += " SNP feature matrix" + df_features.attrs["title"] = title + + return df_features + @numba.jit(nopython=True) def _melt_gt_counts(gt_counts): @@ -999,18 +1147,14 @@ def _map_snp_to_aa_change_frq_ds(ds): ] if ds.sizes["variants"] == 1: - # Keep everything as-is, no need for aggregation. ds_out = ds[keep_vars + ["variant_alt_allele", "event_count"]] else: - # Take the first value from all variants variables. ds_out = ds[keep_vars].isel(variants=[0]) - # Sum event count over variants. count = ds["event_count"].values.sum(axis=0, keepdims=True) ds_out["event_count"] = ("variants", "cohorts"), count - # Collapse alt allele. alt_allele = "{" + ",".join(ds["variant_alt_allele"].values) + "}" ds_out["variant_alt_allele"] = ( "variants", diff --git a/tests/anoph/test_snp_frq.py b/tests/anoph/test_snp_frq.py index b02f512d0..95b997bed 100644 --- a/tests/anoph/test_snp_frq.py +++ b/tests/anoph/test_snp_frq.py @@ -1579,3 +1579,157 @@ def test_allele_frequencies_advanced_with_dup_samples( api=api, sample_sets=sample_sets, ) + + +def check_snp_feature_matrix(api, transcript=None, region=None, cohorts=None, **kwargs): + """Test the snp_feature_matrix function.""" + # Test basic functionality + df_features = api.snp_feature_matrix( + transcript=transcript, region=region, cohorts=cohorts, **kwargs + ) + + # Check that we have a DataFrame + assert isinstance(df_features, pd.DataFrame) + + # Check that we have the expected columns + expected_columns = [ + "total_snp_count", + "nonsynonymous_snp_count", + "mean_allele_frequency", + ] + assert list(df_features.columns) == expected_columns + + # Check that we have data (not empty) + assert len(df_features) > 0 + + # Check data types + assert df_features["total_snp_count"].dtype in [np.int64, np.int32] + assert df_features["nonsynonymous_snp_count"].dtype in [np.int64, np.int32] + assert df_features["mean_allele_frequency"].dtype in [np.float64, np.float32] + + # Check that counts are non-negative + assert (df_features["total_snp_count"] >= 0).all() + assert (df_features["nonsynonymous_snp_count"] >= 0).all() + + # Check that nonsynonymous count is <= total count + assert ( + df_features["nonsynonymous_snp_count"] <= df_features["total_snp_count"] + ).all() + + # Check that mean allele frequency is between 0 and 1 (or NaN) + mean_af = df_features["mean_allele_frequency"] + assert ((mean_af >= 0) & (mean_af <= 1) | mean_af.isna()).all() + + # Check that we have a title attribute + assert "title" in df_features.attrs + assert "SNP feature matrix" in df_features.attrs["title"] + + return df_features + + +@parametrize_with_cases("fixture,api", cases=".") +def test_snp_feature_matrix_cohort_mode(fixture, api: AnophelesSnpFrequencyAnalysis): + """Test snp_feature_matrix in cohort mode.""" + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_set = random.choice(all_sample_sets) + transcript = random_transcript(api=api).name + + # Test with cohorts + cohorts = { + "france": "country == 'France'", + "uk": "country == 'United Kingdom'", + } + + df_features = check_snp_feature_matrix( + api=api, + transcript=transcript, + cohorts=cohorts, + sample_sets=sample_set, + ) + + # Should have one row per cohort + assert len(df_features) == len(cohorts) + + # Check that cohort names are in the index + for cohort_name in cohorts.keys(): + assert cohort_name in df_features.index + + +@parametrize_with_cases("fixture,api", cases=".") +def test_snp_feature_matrix_sample_mode(fixture, api: AnophelesSnpFrequencyAnalysis): + """Test snp_feature_matrix in sample mode.""" + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_set = random.choice(all_sample_sets) + transcript = random_transcript(api=api).name + + # Test without cohorts (sample mode) + df_features = check_snp_feature_matrix( + api=api, + transcript=transcript, + cohorts=None, + sample_sets=sample_set, + ) + + # Should have one row per sample + df_samples = api.sample_metadata(sample_sets=sample_set) + assert len(df_features) == len(df_samples) + + # Check that sample IDs are in the index + for sample_id in df_samples["sample_id"]: + assert sample_id in df_features.index + + +@parametrize_with_cases("fixture,api", cases=".") +def test_snp_feature_matrix_region_mode(fixture, api: AnophelesSnpFrequencyAnalysis): + """Test snp_feature_matrix with region instead of transcript.""" + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_set = random.choice(all_sample_sets) + + # Test with a genomic region + region = "2L:10,000,000-10,100,000" + + df_features = check_snp_feature_matrix( + api=api, + transcript=None, + region=region, + cohorts=None, + sample_sets=sample_set, + ) + + # Should have some data + assert len(df_features) > 0 + + +@parametrize_with_cases("fixture,api", cases=".") +def test_snp_feature_matrix_validation_errors( + fixture, api: AnophelesSnpFrequencyAnalysis +): + """Test that snp_feature_matrix raises appropriate validation errors.""" + + # Test error when neither transcript nor region provided + with pytest.raises(ValueError, match="Provide either transcript or region"): + api.snp_feature_matrix(transcript=None, region=None) + + # Test error when both transcript and region provided + transcript = random_transcript(api=api).name + region = "2L:10,000,000-10,100,000" + with pytest.raises(ValueError, match="Provide only one of transcript or region"): + api.snp_feature_matrix(transcript=transcript, region=region) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_snp_feature_matrix_minimal_parameters( + fixture, api: AnophelesSnpFrequencyAnalysis +): + """Test snp_feature_matrix with minimal required parameters.""" + transcript = random_transcript(api=api).name + + # Test with just transcript (all other parameters optional) + df_features = check_snp_feature_matrix( + api=api, + transcript=transcript, + cohorts=None, + ) + + # Should still work and return valid data + assert len(df_features) > 0