From 325c64e366c23ae0e6cd020a4735170ccb0e3860 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Tue, 2 Dec 2025 14:44:53 +0000 Subject: [PATCH] Allow metadata keys to be specified in samples() Fixes https://github.com/tskit-dev/tskit/issues/1697 --- python/CHANGELOG.rst | 5 +++ python/tests/test_highlevel.py | 82 +++++++++++++++++++++++++++++++++- python/tskit/trees.py | 39 ++++++++++++++-- 3 files changed, 122 insertions(+), 4 deletions(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index d11a2a0e2d..9f08c13b53 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -2,6 +2,11 @@ [1.0.x] - YYYY-MM-DD -------------------- +**Features** + +- ``ts.samples(population=...)`` now accepts dictionaries to filter samples + by population metadata. (:user:`hyanwong`, :issue:`1697` :pr:`3345`) + **Bugfixes** - ``ts.samples(population=...)`` now raises a ``ValueError`` if the population diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 9ce5a928d2..f21a8326d9 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -998,11 +998,91 @@ def test_bad_samples(self, pop): with pytest.raises(ValueError, match="must be an integer ID"): ts.samples(population=pop) - @pytest.mark.parametrize("pop", [0, np.int32(0), np.int64(0), np.uint32(0)]) + @pytest.mark.parametrize( + "pop", [0, np.int32(0), np.int64(0), np.uint32(0), {"name": "pop_0"}, {}] + ) def test_good_samples(self, pop): ts = msprime.sim_ancestry(2) + assert ts.num_populations == 1 assert np.array_equiv(ts.samples(population=pop), ts.samples()) + @pytest.mark.parametrize( + "pop", + [ + {"name": "nonexistent"}, + {"name": "pop_0", "description": "nonexistent"}, + {"name": "pop_0", "nonexistent": ""}, + ], + ) + def test_samples_metadata_no_selected(self, pop): + ts = msprime.sim_ancestry(2) + with pytest.raises( + ValueError, match="No populations match the specified metadata" + ): + ts.samples(population=pop) + + @pytest.mark.parametrize("pop", [{"name": "pop_0"}, {}]) + def test_samples_metadata_nopop(self, pop): + ts = tskit.Tree.generate_balanced(4).tree_sequence + assert ts.num_populations == 0 + with pytest.raises( + ValueError, match="No populations match the specified metadata" + ): + ts.samples(population=pop) + + def test_samples_metadata_multipop(self): + demography = msprime.Demography() + demography.add_population(name="A", initial_size=10_000) + demography.add_population(name="B", initial_size=5_000) + demography.add_population(name="C", initial_size=1_000) + demography.add_population_split(time=1000, derived=["A", "B"], ancestral="C") + samples = {"A": 1, "B": 1} + ts = msprime.sim_ancestry(samples, demography=demography, random_seed=12) + with pytest.raises(ValueError, match=r"populations \(\[0, 1, 2\]\) match"): + ts.samples(population={"description": ""}) + + @pytest.mark.parametrize( + "pop_param", + [ + {"name": "B"}, + {"name": "B", "description": "A&B"}, + {"description": "A&B", "+": "B⊕C"}, + ], + ) + def test_samples_metadata_onepop(self, pop_param): + demography = msprime.Demography() + N = 100 + demography.add_population(name="A", description="A&B", initial_size=N) + demography.add_population( + name="B", description="A&B", extra_metadata={"+": "B⊕C"}, initial_size=N + ) + demography.add_population(name="C", extra_metadata={"+": "B⊕C"}, initial_size=N) + demography.add_population_split(time=1000, derived=["A", "B"], ancestral="C") + ts = msprime.sim_ancestry( + samples={"A": 1, "B": 1}, demography=demography, random_seed=12 + ) + samp = ts.samples(population=pop_param) + id_B = {pop.metadata["name"]: pop.id for pop in ts.populations()}["B"] + assert np.array_equiv(samp, ts.samples(population=id_B)) + + @pytest.mark.parametrize("md", [b"{}", b"", None]) + def test_bad_pop_metadata(self, md): + tables = tskit.Tree.generate_balanced(4).tree_sequence.dump_tables() + tables.populations.add_row(metadata=md) + ts = tables.tree_sequence() + with pytest.raises(ValueError, match="metadata is not a dictionary"): + ts.samples(population={}) + + def test_empty_pop_metadata(self): + # The docs state "Tskit deviates from standard JSON in that + # empty metadata is interpreted as an empty object." - test this + tables = tskit.Tree.generate_balanced(4).tree_sequence.dump_tables() + tables.populations.add_row() + tables.populations.metadata_schema = tskit.MetadataSchema.permissive_json() + tables.nodes.population = np.zeros_like(tables.nodes.population) # all in pop 0 + ts = tables.tree_sequence() + assert np.array_equiv(ts.samples(population={}), ts.samples()) + @pytest.mark.parametrize("time", [0, 0.1, 1 / 3, 1 / 4, 5 / 7]) def test_samples_time(self, time): ts = self.get_tree_sequence(num_demes=2, n=20, times=[time, 0.2, 1, 15]) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 7775231dde..fd2da5543b 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6514,15 +6514,32 @@ def samples(self, population=None, *, population_id=None, time=None): time is approximately equal to the specified time. If `time` is a pair of values of the form `(min_time, max_time)`, only return sample IDs whose node time `t` is in this interval such that `min_time <= t < max_time`. + If both `population` and `time` are specified, the returned samples + will satisfy both criteria. - :param int population: The population of interest. If None, do not - filter samples by population. + .. note:: + The population can be specified either by an integer (in which case + this is the population ID) or a dictionary matching information in the + population metadata. If a dictionary, it should contain key-value pair(s) + that match the metadata of the desired population; for instance, + ``population={'name': 'abc'}`` will select the population that has a + 'name' of 'abc' in metadata: there should be exactly one population + that has matching key-value pair(s), if not, an error is raised. + + :param Union[int, dict] population: The population of interest. If an + integer, this is the population ID. If a dictionary, the keys + in the dictionary specify metadata key-value pairs to match (see note + above). If None, do not filter samples by population. :param int population_id: Deprecated alias for ``population``. :param float,tuple time: The time or time interval of interest. If None, do not filter samples by time. :return: A numpy array of the node IDs for the samples of interest, listed in numerical order. :rtype: numpy.ndarray (dtype=np.int32) + :raises ValueError: If population or time is specified incorrectly. + :raises ValueError: If multiple or no populations match the specified metadata. + :raises ValueError: If a dictionary is specified to select a population + but existing population metadata entries cannot be treated as dictionaries. """ if population is not None and population_id is not None: raise ValueError( @@ -6533,8 +6550,24 @@ def samples(self, population=None, *, population_id=None, time=None): samples = self._ll_tree_sequence.get_samples() keep = np.full(shape=samples.shape, fill_value=True) if population is not None: + if isinstance(population, dict): + # look for the key names in the population metadata: we don't expect + # there to be many populations, so a simple loop is fine. + pops = [] + for pop in self.populations(): + if not isinstance(pop.metadata, dict): + raise ValueError("Population metadata is not a dictionary") + if set(population.items()).issubset(pop.metadata.items()): + pops.append(pop.id) + if len(pops) == 0: + raise ValueError("No populations match the specified metadata") + if len(pops) > 1: + raise ValueError( + f"Multiple populations ({pops}) match the specified metadata" + ) + population = pops[0] if not isinstance(population, numbers.Integral): - raise ValueError("`population` must be an integer ID") + raise ValueError("`population` must be an integer ID or a dictionary") population = int(population) sample_population = self.nodes_population[samples] keep = np.logical_and(keep, sample_population == population)