diff --git a/malariagen_data/anoph/association.py b/malariagen_data/anoph/association.py
new file mode 100644
index 000000000..cdb63be11
--- /dev/null
+++ b/malariagen_data/anoph/association.py
@@ -0,0 +1,153 @@
+from typing import Optional, Dict, Any
+import numpy as np
+import scipy.stats
+
+from numpydoc_decorator import doc  # type: ignore
+
+from . import base_params, phenotype_params
+from .phenotypes import AnophelesPhenotypeData
+from ..util import _check_types, Region
+
+
+class AnophelesAssociationAnalysis(AnophelesPhenotypeData):
+    """
+    Provides methods for testing statistical associations between
+    specific variants and phenotypic traits.
+    Inherited by AnophelesDataResource subclasses (e.g., Ag3).
+    """
+
+    def __init__(self, **kwargs):
+        # Cooperatively initialize parent classes
+        super().__init__(**kwargs)
+
+    @staticmethod
+    def _contig_of(region) -> str:
+        """Return the contig name of a single region given as a str or Region."""
+        if isinstance(region, Region):
+            return str(region.contig)
+        if isinstance(region, str):
+            # Accept both "2L" and "2L:1000-2000": only the contig is needed,
+            # because the query window is rebuilt around the variant position.
+            return region.split(":", 1)[0]
+        raise TypeError(
+            "variant_association requires a single region given as a string "
+            f"or Region, got {type(region).__name__}."
+        )
+
+    @_check_types
+    @doc(
+        summary="Test for association between a specific variant and a binary phenotype.",
+        extended_summary=(
+            "Extract phenotype and genotype data for a single variant position "
+            "and apply a two-sided Fisher's exact test to the 2x2 table of "
+            "binary phenotype versus carrying at least one alternate allele. "
+            "Samples with a missing genotype call, or with a phenotype other "
+            "than 0 or 1, are excluded. Raises ValueError if no samples match, "
+            "if the position is not found, or if no usable samples remain."
+        ),
+        parameters=dict(
+            position="The 1-based coordinate of the variant.",
+        ),
+        returns=dict(
+            stats="A dictionary containing the Fisher's Exact test odds ratio, p-value, and contingency table counts."
+        ),
+    )
+    def variant_association(
+        self,
+        region: base_params.region,
+        position: int,
+        sample_sets: Optional[base_params.sample_sets] = None,
+        sample_query: Optional[base_params.sample_query] = None,
+        sample_query_options: Optional[base_params.sample_query_options] = None,
+        insecticide: Optional[phenotype_params.insecticide] = None,
+        dose: Optional[phenotype_params.dose] = None,
+        phenotype: Optional[phenotype_params.phenotype] = None,
+        cohort_size: Optional[base_params.cohort_size] = None,
+        min_cohort_size: Optional[base_params.min_cohort_size] = None,
+        max_cohort_size: Optional[base_params.max_cohort_size] = None,
+    ) -> Dict[str, Any]:
+        # Query only the requested coordinate: fetching every SNP on a whole
+        # contig (e.g., '2L') just to test one site would be extremely slow.
+        # NOTE(review): gene-ID regions are not resolved to a contig here;
+        # confirm whether they need to be supported.
+        contig = self._contig_of(region)
+        target_region = f"{contig}:{position}-{position}"
+
+        # Fetch the merged phenotype and genotype dataset for the target site
+        ds = self.phenotypes_with_snps(
+            region=target_region,
+            sample_sets=sample_sets,
+            sample_query=sample_query,
+            sample_query_options=sample_query_options,
+            cohort_size=cohort_size,
+            min_cohort_size=min_cohort_size,
+            max_cohort_size=max_cohort_size,
+        )
+
+        n_samples = ds.sizes.get("samples", 0)
+        if n_samples == 0:
+            raise ValueError("No matching records found for the given criteria.")
+
+        # Locate the requested variant among the returned sites
+        pos_mask = ds["variant_position"].values == position
+        if not np.any(pos_mask):
+            raise ValueError(
+                f"Variant position {position} not found in region {region}."
+            )
+
+        # Optional phenotype metadata filters
+        # (Alternatively, users can embed these in sample_query)
+        keep = np.ones(n_samples, dtype=bool)
+        if insecticide is not None:
+            keep &= ds["insecticide"].values == insecticide
+        if dose is not None:
+            keep &= ds["dose"].values == dose
+        if phenotype is not None:
+            keep &= ds["phenotype"].values.astype(str) == str(phenotype)
+
+        phenos = ds["phenotype_binary"].values[keep]
+        # call_genotype has shape (variants, samples, ploidy); take the row for
+        # `position`, giving (samples, ploidy), then apply the sample filter
+        gt = ds["call_genotype"].values[pos_mask][0][keep]
+
+        # Keep samples with a binary phenotype (NaN or any value other than 0/1
+        # is dropped) and a fully called genotype (-1 marks a missing call)
+        usable = np.isin(phenos, (0, 1)) & (gt.min(axis=1) >= 0)
+        phenos_valid = phenos[usable]
+        gt_valid = gt[usable]
+        if phenos_valid.size == 0:
+            raise ValueError(
+                "No samples with both a binary phenotype and a called genotype "
+                "remain after filtering."
+            )
+
+        # "Has alt": any allele in the genotype call is > 0 (e.g. 0/1 or 1/1);
+        # everything else here is homozygous reference (0/0)
+        has_alt = (gt_valid > 0).any(axis=1)
+        pheno_positive = phenos_valid == 1
+
+        # Build contingency table, as plain ints so the result is serialisable:
+        #        Alt   Ref
+        # Pos     a     b
+        # Neg     c     d
+        a = int(np.count_nonzero(pheno_positive & has_alt))
+        b = int(np.count_nonzero(pheno_positive & ~has_alt))
+        c = int(np.count_nonzero(~pheno_positive & has_alt))
+        d = int(np.count_nonzero(~pheno_positive & ~has_alt))
+        table = [[a, b], [c, d]]
+
+        # Unpacking works whether scipy returns a plain tuple or a result object
+        odds_ratio, p_value = scipy.stats.fisher_exact(table, alternative="two-sided")
+
+        return {
+            "region": region,
+            "position": position,
+            "contingency_table": table,
+            "phenotype_positive_alt": a,
+            "phenotype_positive_ref": b,
+            "phenotype_negative_alt": c,
+            "phenotype_negative_ref": d,
+            "odds_ratio": float(odds_ratio),
+            "p_value": float(p_value),
+            "total_valid_samples": a + b + c + d,
+        }
diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py
index 84e2eb969..e5462d709 100644
--- a/malariagen_data/anopheles.py
+++ b/malariagen_data/anopheles.py
@@ -47,6 +47,7 @@
 from .anoph.describe import AnophelesDescribe
 from .anoph.dipclust import AnophelesDipClustAnalysis
 from .anoph.heterozygosity import AnophelesHetAnalysis
+from .anoph.association import AnophelesAssociationAnalysis
 from .util import (
     CacheMiss,
     Region,  # noqa: F401 (re-exported via __init__.py)
@@ -78,6 +79,7 @@
 # work around pycharm failing to recognise that doc() is callable
 # noinspection PyCallingNonCallable
 class AnophelesDataResource(
+    AnophelesAssociationAnalysis,
     AnophelesDipClustAnalysis,
     AnophelesHapClustAnalysis,
     AnophelesH1XAnalysis,
diff --git a/tests/anoph/test_association.py b/tests/anoph/test_association.py
new file mode 100644
index 000000000..47084cf06
--- /dev/null
+++ b/tests/anoph/test_association.py
@@ -0,0 +1,143 @@
+import pytest
+import numpy as np
+import xarray as xr
+from malariagen_data.anoph.association import AnophelesAssociationAnalysis
+
+
+class DummyAssociationAPI(AnophelesAssociationAnalysis):
+    def __init__(self):
+        # Skip cooperative multiple inheritance for pure logic testing
+        pass
+
+
+def mock_phenotypes_with_snps(
+    region,
+    sample_sets=None,
+    sample_query=None,
+    sample_query_options=None,
+    cohort_size=None,
+    min_cohort_size=None,
+    max_cohort_size=None,
+):
+    """
+    Mock the exact structure returned by phenotypes_with_snps(), with
+    perfectly known phenotypic and genotypic correlations.
+    """
+    # We will create 10 samples, 1 variant at position 1000
+    # 5 samples ALIVE (1), 5 samples DEAD (0)
+    # All ALIVE samples carry at least one alternate allele (0/1 or 1/1)
+    # All DEAD samples have reference genotype (0/0)
+    samples = [f"SAM{i:03d}" for i in range(10)]
+    phenotype_binary = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=float)
+    insecticide = np.array(["Permethrin"] * 10)
+    dose = np.array(["1x"] * 10)
+    phenotype = np.array(["alive"] * 5 + ["dead"] * 5)
+
+    call_genotype = np.array(
+        [
+            [
+                # Sample 0..4 (Alive, has alt)
+                [1, 1],
+                [0, 1],
+                [1, 0],
+                [1, 1],
+                [1, 1],
+                # Sample 5..9 (Dead, ref only)
+                [0, 0],
+                [0, 0],
+                [0, 0],
+                [0, 0],
+                [0, 0],
+            ]
+        ],
+        dtype="i1",
+    )
+
+    variant_position = np.array([1000])
+
+    return xr.Dataset(
+        {
+            "phenotype_binary": (["samples"], phenotype_binary),
+            "insecticide": (["samples"], insecticide),
+            "dose": (["samples"], dose),
+            "phenotype": (["samples"], phenotype),
+            "call_genotype": (["variants", "samples", "ploidy"], call_genotype),
+            "variant_position": (["variants"], variant_position),
+        },
+        coords={
+            "samples": samples,
+        },
+    )
+
+
+def test_variant_association_logic():
+    """
+    Test the statistical logic of variant_association by injecting a
+    mock xarray Dataset with perfectly known phenotypic and genotypic correlations.
+    """
+    api = DummyAssociationAPI()
+
+    # Patch the method on the instance
+    api.phenotypes_with_snps = mock_phenotypes_with_snps
+
+    # Now call the mathematical feature
+    res = api.variant_association(region="2L", position=1000, insecticide="Permethrin")
+
+    # In our exact scenario:
+    # alt_alive = 5 (a)
+    # ref_alive = 0 (b)
+    # alt_dead = 0 (c)
+    # ref_dead = 5 (d)
+
+    assert res["region"] == "2L"
+    assert res["position"] == 1000
+    assert res["contingency_table"] == [[5, 0], [0, 5]]
+    assert res["phenotype_positive_alt"] == 5
+    assert res["phenotype_positive_ref"] == 0
+    assert res["phenotype_negative_alt"] == 0
+    assert res["phenotype_negative_ref"] == 5
+    assert res["total_valid_samples"] == 10
+
+    # With a perfect split of 5/5, the Odds Ratio is technically infinity,
+    # and the P-value should be extremely significant (< 0.05)
+    assert res["p_value"] < 0.05
+
+
+def test_variant_association_region_with_coordinates():
+    api = DummyAssociationAPI()
+    requested = []
+
+    def recording_mock(region, **kwargs):
+        requested.append(region)
+        return mock_phenotypes_with_snps(region, **kwargs)
+
+    api.phenotypes_with_snps = recording_mock
+
+    res = api.variant_association(region="2L:1-2000", position=1000)
+
+    # Only the contig is kept; the query window is rebuilt around the position
+    assert requested == ["2L:1000-1000"]
+    assert res["total_valid_samples"] == 10
+
+
+def test_variant_association_no_usable_samples():
+    api = DummyAssociationAPI()
+    api.phenotypes_with_snps = mock_phenotypes_with_snps
+
+    # No sample was exposed to this insecticide, so nothing is left to test
+    with pytest.raises(ValueError, match="remain after filtering"):
+        api.variant_association(region="2L", position=1000, insecticide="DDT")
+
+
+def test_variant_association_not_found():
+    api = DummyAssociationAPI()
+
+    def mock_empty(*args, **kwargs):
+        return xr.Dataset(
+            {"variant_position": (["variants"], [1000])}, coords={"samples": ["SAM001"]}
+        )
+
+    api.phenotypes_with_snps = mock_empty
+
+    with pytest.raises(ValueError, match="Variant position 9999 not found"):
+        api.variant_association(region="2L", position=9999)