From 15fd8496601757d105c03bf4c15e7fa614dfcf88 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Mon, 1 Dec 2025 09:53:29 +0000 Subject: [PATCH] Raise error if pop passed to samples is not an integer Co-authored-by: Peter Ralph --- python/CHANGELOG.rst | 10 ++++++++++ python/tests/test_highlevel.py | 11 +++++++++++ python/tskit/trees.py | 3 +++ 3 files changed, 24 insertions(+) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 072d4af498..d11a2a0e2d 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -1,3 +1,13 @@ +-------------------- +[1.0.x] - YYYY-MM-DD +-------------------- + +**Bugfixes** + +- ``ts.samples(population=...)`` now raises a ``ValueError`` if ``population`` + is not an integer ID (e.g. a population name), rather than silently returning + no samples. (:user:`hyanwong`, :pr:`3344`) + -------------------- [1.0.0] - 2025-11-27 -------------------- diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index eb1f0ed4ab..9ce5a928d2 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -992,6 +992,17 @@ def test_samples(self): ] assert total == ts.num_samples + @pytest.mark.parametrize("pop", ["string", "", "0", np.arange(2), 0.0, 0.5, np.nan]) + def test_bad_samples(self, pop): + ts = tskit.Tree.generate_balanced(4).tree_sequence + with pytest.raises(ValueError, match="must be an integer ID"): + ts.samples(population=pop) + + @pytest.mark.parametrize("pop", [0, np.int32(0), np.int64(0), np.uint32(0)]) + def test_good_samples(self, pop): + ts = msprime.sim_ancestry(2) + assert np.array_equiv(ts.samples(population=pop), ts.samples()) + @pytest.mark.parametrize("time", [0, 0.1, 1 / 3, 1 / 4, 5 / 7]) def test_samples_time(self, time): ts = self.get_tree_sequence(num_demes=2, n=20, times=[time, 0.2, 1, 15]) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 2756d005e8..9210923bbe 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6533,6 +6533,9 @@ def samples(self, 
population=None, *, population_id=None, time=None): samples = self._ll_tree_sequence.get_samples() keep = np.full(shape=samples.shape, fill_value=True) if population is not None: + if not isinstance(population, numbers.Integral): + raise ValueError("`population` must be an integer ID") + population = int(population) sample_population = self.nodes_population[samples] keep = np.logical_and(keep, sample_population == population) if time is not None: