From 88efe1c47b2eaed55396e8c2854395f444a1cdb2 Mon Sep 17 00:00:00 2001 From: Daniel Goldstein Date: Fri, 22 May 2020 16:10:31 -0400 Subject: [PATCH] Implement algorithm for calculating species tree topology distributions. --- docs/python-api.rst | 2 + python/tests/test_combinatorics.py | 540 +++++++++++++++++++++++++++++ python/tskit/__init__.py | 1 + python/tskit/combinatorics.py | 255 +++++++++++++- python/tskit/trees.py | 74 ++++ 5 files changed, 866 insertions(+), 6 deletions(-) diff --git a/docs/python-api.rst b/docs/python-api.rst index 6a7439eced..db87897c30 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -502,6 +502,8 @@ See :ref:`sec_combinatorics` for details. .. autofunction:: tskit.all_tree_labellings +.. autoclass:: tskit.TopologyCounter + ********************** Linkage disequilibrium ********************** diff --git a/python/tests/test_combinatorics.py b/python/tests/test_combinatorics.py index 442515790f..cd45bc5673 100644 --- a/python/tests/test_combinatorics.py +++ b/python/tests/test_combinatorics.py @@ -23,11 +23,18 @@ """ Test cases for combinatorial algorithms. 
""" +import collections +import io import itertools import unittest +import msprime +import numpy as np + +import tests.test_wright_fisher as wf import tskit import tskit.combinatorics as comb +from tests import test_stats from tskit.combinatorics import RankTree @@ -351,6 +358,24 @@ def test_to_from_tsk_tree(self): self.assertTrue(tree.is_canonical()) self.assertEqual(tree, reconstructed) + def test_from_unary_tree(self): + tables = tskit.TableCollection(sequence_length=1) + c = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + p = tables.nodes.add_row(time=1) + tables.edges.add_row(left=0, right=1, parent=p, child=c) + + t = tables.tree_sequence().first() + with self.assertRaises(ValueError): + RankTree.from_tsk_tree(t) + + def test_to_tsk_tree_errors(self): + alpha_tree = RankTree.unrank((0, 0), 3, ["A", "B", "C"]) + out_of_bounds_tree = RankTree.unrank((0, 0), 3, [2, 3, 4]) + with self.assertRaises(ValueError): + alpha_tree.to_tsk_tree() + with self.assertRaises(ValueError): + out_of_bounds_tree.to_tsk_tree() + def test_rank_errors_multiple_roots(self): tables = tskit.TableCollection(sequence_length=1.0) @@ -420,3 +445,518 @@ def test_is_symmetrical(self): self.assertFalse(three_leaf_asym.is_symmetrical()) six_leaf_sym = RankTree(children=[three_leaf_asym, three_leaf_asym]) self.assertTrue(six_leaf_sym.is_symmetrical()) + + +class TestPartialTopologyCounter(unittest.TestCase): + def test_add_sibling_topologies_simple(self): + a = RankTree(children=[], label="A") + b = RankTree(children=[], label="B") + ab = RankTree(children=[a, b]) + + a_counter = comb.TopologyCounter() + a_counter["A"][a.rank()] = 1 + self.assertEqual(a_counter, comb.TopologyCounter.from_sample("A")) + + b_counter = comb.TopologyCounter() + b_counter["B"][b.rank()] = 1 + self.assertEqual(b_counter, comb.TopologyCounter.from_sample("B")) + + partial_counter = comb.PartialTopologyCounter() + partial_counter.add_sibling_topologies(a_counter) + 
partial_counter.add_sibling_topologies(b_counter) + + expected = comb.TopologyCounter() + expected["A"][a.rank()] = 1 + expected["B"][b.rank()] = 1 + expected["A", "B"][ab.rank()] = 1 + joined_counter = partial_counter.join_all_combinations() + self.assertEqual(joined_counter, expected) + + def test_add_sibling_topologies_polytomy(self): + """ + Goes through the topology-merging step at the root + of this tree: + | + | + +----+-----+----+ + | | | | + | | | | + | | | +---+ + | | | | | + | | | | | + A A B A C + """ + partial_counter = comb.PartialTopologyCounter() + a = RankTree(children=[], label="A") + c = RankTree(children=[], label="C") + ac = RankTree(children=[a, c]) + + expected = collections.defaultdict(collections.Counter) + + a_counter = comb.TopologyCounter.from_sample("A") + b_counter = comb.TopologyCounter.from_sample("B") + ac_counter = comb.TopologyCounter() + ac_counter["A"][a.rank()] = 1 + ac_counter["C"][c.rank()] = 1 + ac_counter["A", "C"][ac.rank()] = 1 + + partial_counter.add_sibling_topologies(a_counter) + expected[("A",)] = collections.Counter({((("A",), (0, 0)),): 1}) + self.assertEqual(partial_counter.partials, expected) + + partial_counter.add_sibling_topologies(a_counter) + expected[("A",)][((("A",), (0, 0)),)] += 1 + self.assertEqual(partial_counter.partials, expected) + + partial_counter.add_sibling_topologies(b_counter) + expected[("B",)][((("B",), (0, 0)),)] = 1 + expected[("A", "B")][((("A",), (0, 0)), (("B",), (0, 0)))] = 2 + self.assertEqual(partial_counter.partials, expected) + + partial_counter.add_sibling_topologies(ac_counter) + expected[("A",)][((("A",), (0, 0)),)] += 1 + expected[("C",)][((("C",), (0, 0)),)] = 1 + expected[("A", "B")][((("A",), (0, 0)), (("B",), (0, 0)))] += 1 + expected[("A", "C")][((("A",), (0, 0)), (("C",), (0, 0)))] = 2 + expected[("A", "C")][((("A", "C"), (0, 0)),)] = 1 + expected[("B", "C")][((("B",), (0, 0)), (("C",), (0, 0)))] = 1 + expected[("A", "B", "C")][ + ((("A",), (0, 0)), (("B",), (0, 0)), 
(("C",), (0, 0))) + ] = 2 + expected[("A", "B", "C")][((("A", "C"), (0, 0)), (("B",), (0, 0)))] = 1 + self.assertEqual(partial_counter.partials, expected) + + expected_topologies = comb.TopologyCounter() + expected_topologies["A"][(0, 0)] = 3 + expected_topologies["B"][(0, 0)] = 1 + expected_topologies["C"][(0, 0)] = 1 + expected_topologies["A", "B"][(0, 0)] = 3 + expected_topologies["A", "C"][(0, 0)] = 3 + expected_topologies["B", "C"][(0, 0)] = 1 + expected_topologies["A", "B", "C"][(0, 0)] = 2 + expected_topologies["A", "B", "C"][(1, 1)] = 1 + joined_topologies = partial_counter.join_all_combinations() + self.assertEqual(joined_topologies, expected_topologies) + + def test_join_topologies(self): + a = RankTree(children=[], label="A") + b = RankTree(children=[], label="B") + c = RankTree(children=[], label="C") + a_tuple = (("A"), a.rank()) + b_tuple = (("B"), b.rank()) + c_tuple = (("C"), c.rank()) + ab_tuple = (("A", "B"), RankTree(children=[a, b]).rank()) + ac_tuple = (("A", "C"), RankTree(children=[a, c]).rank()) + bc_tuple = (("B", "C"), RankTree(children=[b, c]).rank()) + + self.verify_join_topologies((a_tuple, b_tuple), (0, 0)) + self.verify_join_topologies((b_tuple, a_tuple), (0, 0)) + self.verify_join_topologies((b_tuple, c_tuple), (0, 0)) + + self.verify_join_topologies((a_tuple, b_tuple, c_tuple), (0, 0)) + self.verify_join_topologies((a_tuple, bc_tuple), (1, 0)) + self.verify_join_topologies((b_tuple, ac_tuple), (1, 1)) + self.verify_join_topologies((c_tuple, ab_tuple), (1, 2)) + + def verify_join_topologies(self, topologies, expected_topology): + actual_topology = comb.PartialTopologyCounter.join_topologies(topologies) + self.assertEqual(actual_topology, expected_topology) + + +class TestCountTopologies(unittest.TestCase): + def verify_topologies(self, ts, sample_sets=None, expected=None): + if sample_sets is None: + sample_sets = [ts.samples(population=pop.id) for pop in ts.populations()] + topologies = [t.count_topologies(sample_sets) for t in 
ts.trees()] + inc_topologies = list(ts.count_topologies(sample_sets)) + # count_topologies calculates the embedded topologies for every + # combination of populations, so we need to check the results + # of subsampling for every combination. + for num_sample_sets in range(1, len(sample_sets) + 1): + for i, t in enumerate(ts.trees()): + just_t = ts.keep_intervals([t.interval], simplify=False) + for sample_set_indexes in itertools.combinations( + range(len(sample_sets)), num_sample_sets + ): + actual_topologies = topologies[i][sample_set_indexes] + actual_inc_topologies = inc_topologies[i][sample_set_indexes] + if len(t.roots) == 1: + subsampled_topologies = self.subsample_topologies( + just_t, sample_sets, sample_set_indexes + ) + self.assertEqual(actual_topologies, subsampled_topologies) + if expected is not None: + self.assertEqual( + actual_topologies, expected[i][sample_set_indexes] + ) + self.assertEqual(actual_topologies, actual_inc_topologies) + + def subsample_topologies(self, ts, sample_sets, sample_set_indexes): + subsample_sets = [sample_sets[i] for i in sample_set_indexes] + topologies = collections.Counter() + for subsample in itertools.product(*subsample_sets): + for pop_tree in ts.simplify(samples=subsample).trees(): + # regions before and after keep interval have all samples as roots + # so don't count those + # The single tree of interest should have one root + if len(pop_tree.roots) == 1: + topologies[pop_tree.rank()] += 1 + return topologies + + def test_single_population(self): + n = 10 + ts = msprime.simulate(n, recombination_rate=10) + expected = comb.TopologyCounter() + expected[0] = collections.Counter({(0, 0): n}) + self.verify_topologies(ts, expected=[expected] * ts.num_trees) + + def test_three_populations(self): + nodes = io.StringIO( + """\ + id is_sample time population individual metadata + 0 1 0.000000 0 -1 + 1 1 0.000000 1 -1 + 2 1 0.000000 1 -1 + 3 1 0.000000 2 -1 + 4 1 0.000000 2 -1 + 5 1 0.000000 0 -1 + 6 0 1.000000 0 -1 + 7 0 
2.000000 0 -1 + 8 0 2.000000 0 -1 + 9 0 3.000000 0 -1 + 10 0 4.000000 0 -1 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0.000000 1.000000 6 4 + 0.000000 1.000000 6 5 + 0.000000 1.000000 7 1 + 0.000000 1.000000 7 2 + 0.000000 1.000000 8 3 + 0.000000 1.000000 8 6 + 0.000000 1.000000 9 7 + 0.000000 1.000000 9 8 + 0.000000 1.000000 10 0 + 0.000000 1.000000 10 9 + """ + ) + ts = tskit.load_text( + nodes, edges, sequence_length=1, strict=False, base64_metadata=False + ) + + expected = comb.TopologyCounter() + expected[0] = collections.Counter({(0, 0): 2}) + expected[1] = collections.Counter({(0, 0): 2}) + expected[2] = collections.Counter({(0, 0): 2}) + expected[0, 1] = collections.Counter({(0, 0): 4}) + expected[0, 2] = collections.Counter({(0, 0): 4}) + expected[1, 2] = collections.Counter({(0, 0): 4}) + expected[0, 1, 2] = collections.Counter({(1, 0): 4, (1, 1): 4}) + self.verify_topologies(ts, expected=[expected]) + + def test_multiple_roots(self): + tables = tskit.TableCollection(sequence_length=1.0) + tables.populations.add_row() + tables.populations.add_row() + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=0) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=1) + + # Not samples so they are ignored + tables.nodes.add_row(time=1) + tables.nodes.add_row(time=1, population=1) + + expected = comb.TopologyCounter() + expected[0] = collections.Counter({(0, 0): 1}) + expected[1] = collections.Counter({(0, 0): 1}) + self.verify_topologies(tables.tree_sequence(), expected=[expected]) + + def test_no_sample_subtrees(self): + tables = tskit.TableCollection(sequence_length=1.0) + c1 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + c2 = tables.nodes.add_row(time=0) + c3 = tables.nodes.add_row(time=0) + p1 = tables.nodes.add_row(time=1) + p2 = tables.nodes.add_row(time=1) + + tables.edges.add_row(left=0, right=1, parent=p1, child=c2) + tables.edges.add_row(left=0, right=1, parent=p1, child=c3) + 
tables.edges.add_row(left=0, right=1, parent=p2, child=c1) + + expected = comb.TopologyCounter() + expected[0] = collections.Counter({(0, 0): 1}) + self.verify_topologies(tables.tree_sequence(), expected=[expected]) + + def test_no_full_topology(self): + tables = tskit.TableCollection(sequence_length=1.0) + tables.populations.add_row() + tables.populations.add_row() + tables.populations.add_row() + child1 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=0) + child2 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=1) + parent = tables.nodes.add_row(time=1) + tables.edges.add_row(left=0, right=1, parent=parent, child=child1) + tables.edges.add_row(left=0, right=1, parent=parent, child=child2) + + # Left as root so there is no topology with all three populations + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=2) + + expected = comb.TopologyCounter() + for pop_combo in [(0,), (1,), (2,), (0, 1)]: + expected[pop_combo] = collections.Counter({(0, 0): 1}) + self.verify_topologies(tables.tree_sequence(), expected=[expected]) + + def test_polytomies(self): + tables = tskit.TableCollection(sequence_length=1.0) + tables.populations.add_row() + tables.populations.add_row() + tables.populations.add_row() + c1 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=0) + c2 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=1) + c3 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=2) + p = tables.nodes.add_row(time=1) + tables.edges.add_row(left=0, right=1, parent=p, child=c1) + tables.edges.add_row(left=0, right=1, parent=p, child=c2) + tables.edges.add_row(left=0, right=1, parent=p, child=c3) + + expected = comb.TopologyCounter() + for pop_combos in [0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)]: + expected[pop_combos] = collections.Counter({(0, 0): 1}) + self.verify_topologies(tables.tree_sequence(), expected=[expected]) + + def test_custom_key(self): + nodes = 
io.StringIO( + """\ + id is_sample time population individual metadata + 0 1 0.000000 0 -1 + 1 1 0.000000 0 -1 + 2 1 0.000000 0 -1 + 3 1 0.000000 0 -1 + 4 1 0.000000 0 -1 + 5 0 1.000000 0 -1 + 6 0 1.000000 0 -1 + 7 0 2.000000 0 -1 + 8 0 3.000000 0 -1 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0.000000 1.000000 5 0 + 0.000000 1.000000 5 1 + 0.000000 1.000000 6 2 + 0.000000 1.000000 6 3 + 0.000000 1.000000 7 5 + 0.000000 1.000000 7 6 + 0.000000 1.000000 8 4 + 0.000000 1.000000 8 7 + """ + ) + ts = tskit.load_text( + nodes, edges, sequence_length=1, strict=False, base64_metadata=False + ) + + sample_sets = [[0, 1], [2, 3], [4]] + + expected = comb.TopologyCounter() + expected[0] = collections.Counter({(0, 0): 2}) + expected[1] = collections.Counter({(0, 0): 2}) + expected[2] = collections.Counter({(0, 0): 1}) + expected[0, 1] = collections.Counter({(0, 0): 4}) + expected[0, 2] = collections.Counter({(0, 0): 2}) + expected[1, 2] = collections.Counter({(0, 0): 2}) + expected[0, 1, 2] = collections.Counter({(1, 2): 4}) + + tree_topologies = ts.first().count_topologies(sample_sets) + treeseq_topologies = list(ts.count_topologies(sample_sets)) + self.assertEqual(tree_topologies, expected) + self.assertEqual(treeseq_topologies, [expected]) + + def test_ignores_non_sample_leaves(self): + nodes = io.StringIO( + """\ + id is_sample time population individual metadata + 0 1 0.000000 0 -1 + 1 0 0.000000 0 -1 + 2 1 0.000000 0 -1 + 3 0 0.000000 0 -1 + 4 1 0.000000 0 -1 + 5 0 1.000000 0 -1 + 6 0 1.000000 0 -1 + 7 0 2.000000 0 -1 + 8 0 3.000000 0 -1 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0.000000 1.000000 5 0 + 0.000000 1.000000 5 1 + 0.000000 1.000000 6 2 + 0.000000 1.000000 6 3 + 0.000000 1.000000 7 5 + 0.000000 1.000000 7 6 + 0.000000 1.000000 8 4 + 0.000000 1.000000 8 7 + """ + ) + ts = tskit.load_text( + nodes, edges, sequence_length=1, strict=False, base64_metadata=False + ) + + sample_sets = [[0], [2], [4]] + + expected = 
comb.TopologyCounter() + expected[0] = collections.Counter({(0, 0): 1}) + expected[1] = collections.Counter({(0, 0): 1}) + expected[2] = collections.Counter({(0, 0): 1}) + expected[0, 1] = collections.Counter({(0, 0): 1}) + expected[0, 2] = collections.Counter({(0, 0): 1}) + expected[1, 2] = collections.Counter({(0, 0): 1}) + expected[0, 1, 2] = collections.Counter({(1, 2): 1}) + + tree_topologies = ts.first().count_topologies(sample_sets) + treeseq_topologies = list(ts.count_topologies(sample_sets)) + self.assertEqual(tree_topologies, expected) + self.assertEqual(treeseq_topologies, [expected]) + + def test_internal_samples_errors(self): + tables = tskit.TableCollection(sequence_length=1.0) + + c1 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + c2 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + p = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=1) + + tables.edges.add_row(left=0, right=1, parent=p, child=c1) + tables.edges.add_row(left=0, right=1, parent=p, child=c2) + + self.verify_value_error(tables.tree_sequence()) + + def test_non_sample_nodes_errors(self): + tables = tskit.TableCollection(sequence_length=1.0) + + c1 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + c2 = tables.nodes.add_row(time=0) + p = tables.nodes.add_row(time=1) + + tables.edges.add_row(left=0, right=1, parent=p, child=c1) + tables.edges.add_row(left=0, right=1, parent=p, child=c2) + + sample_sets = [[0], [1]] + self.verify_value_error(tables.tree_sequence(), sample_sets) + + sample_sets = [[0], [-1]] + self.verify_node_out_of_bounds_error(tables.tree_sequence(), sample_sets) + + def verify_value_error(self, ts, sample_sets=None): + with self.assertRaises(ValueError): + ts.first().count_topologies(sample_sets) + with self.assertRaises(ValueError): + list(ts.count_topologies(sample_sets)) + + def verify_node_out_of_bounds_error(self, ts, sample_sets=None): + with self.assertRaises(ValueError): + ts.first().count_topologies(sample_sets) + with 
self.assertRaises(IndexError): + list(ts.count_topologies(sample_sets)) + + def test_standard_msprime_migrations(self): + for num_populations in range(2, 5): + samples = [5] * num_populations + ts = self.simulate_multiple_populations(samples) + self.verify_topologies(ts) + + def simulate_multiple_populations(self, sample_sizes): + d = len(sample_sizes) + M = 0.2 + m = M / (2 * (d - 1)) + + migration_matrix = [ + [m if k < d and k == i + 1 else 0 for k in range(d)] for i in range(d) + ] + + pop_configurations = [ + msprime.PopulationConfiguration(sample_size=size) for size in sample_sizes + ] + return msprime.simulate( + population_configurations=pop_configurations, + migration_matrix=migration_matrix, + recombination_rate=0.1, + ) + + def test_msprime_dtwf(self): + migration_matrix = np.zeros((4, 4)) + population_configurations = [ + msprime.PopulationConfiguration( + sample_size=10, initial_size=10, growth_rate=0 + ), + msprime.PopulationConfiguration( + sample_size=10, initial_size=10, growth_rate=0 + ), + msprime.PopulationConfiguration( + sample_size=10, initial_size=10, growth_rate=0 + ), + msprime.PopulationConfiguration( + sample_size=0, initial_size=10, growth_rate=0 + ), + ] + demographic_events = [ + msprime.PopulationParametersChange(population=1, time=0.1, initial_size=5), + msprime.PopulationParametersChange(population=0, time=0.2, initial_size=5), + msprime.MassMigration(time=1.1, source=0, dest=2), + msprime.MassMigration(time=1.2, source=1, dest=3), + msprime.MigrationRateChange(time=2.1, rate=0.3, matrix_index=(2, 3)), + msprime.MigrationRateChange(time=2.2, rate=0.3, matrix_index=(3, 2)), + ] + ts = msprime.simulate( + migration_matrix=migration_matrix, + population_configurations=population_configurations, + demographic_events=demographic_events, + random_seed=2, + model="dtwf", + ) + + self.verify_topologies(ts) + + def test_forward_time_wright_fisher_unsimplified_all_sample_sets(self): + tables = wf.wf_sim( + 4, + 5, + seed=1, + 
deep_history=False, + initial_generation_samples=False, + num_loci=10, + ) + tables.sort() + ts = tables.tree_sequence() + for S in test_stats.set_partitions(list(ts.samples())): + self.verify_topologies(ts, sample_sets=S) + + def test_forward_time_wright_fisher_unsimplified(self): + tables = wf.wf_sim( + 20, + 15, + seed=1, + deep_history=False, + initial_generation_samples=False, + num_loci=20, + ) + tables.sort() + ts = tables.tree_sequence() + samples = ts.samples() + self.verify_topologies(ts, sample_sets=[samples[:10], samples[10:]]) + + def test_forward_time_wright_fisher_simplified(self): + tables = wf.wf_sim( + 30, + 10, + seed=1, + deep_history=False, + initial_generation_samples=False, + num_loci=5, + ) + tables.sort() + ts = tables.tree_sequence() + samples = ts.samples() + self.verify_topologies(ts, sample_sets=[samples[:10], samples[10:]]) diff --git a/python/tskit/__init__.py b/python/tskit/__init__.py index b539af628d..dd3d68e6a7 100644 --- a/python/tskit/__init__.py +++ b/python/tskit/__init__.py @@ -56,6 +56,7 @@ all_trees, all_tree_shapes, all_tree_labellings, + TopologyCounter, ) from tskit.exceptions import * # NOQA from tskit.util import * # NOQA diff --git a/python/tskit/combinatorics.py b/python/tskit/combinatorics.py index 0b0c050759..67dcaca33b 100644 --- a/python/tskit/combinatorics.py +++ b/python/tskit/combinatorics.py @@ -24,13 +24,237 @@ Module for ranking and unranking trees. Trees are considered only leaf-labelled and unordered, so order of children does not influence equality. 
""" +import collections +import functools import heapq import itertools -from functools import lru_cache + +import numpy as np import tskit +def treeseq_count_topologies(ts, sample_sets): + topology_counter = np.full(ts.num_nodes, None, dtype=object) + parent = np.full(ts.num_nodes, -1) + + def update_state(tree, u): + stack = [u] + while len(stack) > 0: + v = stack.pop() + children = [] + for c in tree.children(v): + if topology_counter[c] is not None: + children.append(topology_counter[c]) + if len(children) > 0: + topology_counter[v] = combine_child_topologies(children) + else: + topology_counter[v] = None + p = parent[v] + if p != -1: + stack.append(p) + + for sample_set_index, sample_set in enumerate(sample_sets): + for u in sample_set: + if not ts.node(u).is_sample(): + raise ValueError(f"Node {u} in sample_sets is not a sample.") + topology_counter[u] = TopologyCounter.from_sample(sample_set_index) + + for tree, (_, edges_out, edges_in) in zip(ts.trees(), ts.edge_diffs()): + # Avoid recomputing anything for the parent until all child edges + # for that parent are inserted/removed + for p, sibling_edges in itertools.groupby(edges_out, key=lambda e: e.parent): + for e in sibling_edges: + parent[e.child] = -1 + update_state(tree, p) + for p, sibling_edges in itertools.groupby(edges_in, key=lambda e: e.parent): + if tree.is_sample(p): + raise ValueError("Internal samples not supported.") + for e in sibling_edges: + parent[e.child] = p + update_state(tree, p) + + counters = [] + for root in tree.roots: + if topology_counter[root] is not None: + counters.append(topology_counter[root]) + yield TopologyCounter.merge(counters) + + +def tree_count_topologies(tree, sample_sets): + for u in tree.samples(): + if not tree.is_leaf(u): + raise ValueError("Internal samples not supported.") + + topology_counter = np.full(tree.tree_sequence.num_nodes, None, dtype=object) + for sample_set_index, sample_set in enumerate(sample_sets): + for u in sample_set: + if not 
tree.is_sample(u): + raise ValueError(f"Node {u} in sample_sets is not a sample.") + topology_counter[u] = TopologyCounter.from_sample(sample_set_index) + + for u in tree.nodes(order="postorder"): + children = [] + for v in tree.children(u): + if topology_counter[v] is not None: + children.append(topology_counter[v]) + if len(children) > 0: + topology_counter[u] = combine_child_topologies(children) + + counters = [] + for root in tree.roots: + if topology_counter[root] is not None: + counters.append(topology_counter[root]) + return TopologyCounter.merge(counters) + + +def combine_child_topologies(topology_counters): + """ + Select all combinations of topologies from different + counters in ``topology_counters`` that are capable of + being combined into a single topology. This includes + any combination of at least two topologies, all from + different children, where no topologies share a + sample set index. + """ + partial_topologies = PartialTopologyCounter() + for tc in topology_counters: + partial_topologies.add_sibling_topologies(tc) + + return partial_topologies.join_all_combinations() + + +class TopologyCounter: + """ + Contains the distributions of embedded topologies for every combination + of the sample sets used to generate the ``TopologyCounter``. It is + indexable by a combination of sample set indexes and returns a + ``collections.Counter`` whose keys are topology ranks + (see :ref:`sec_tree_ranks`). See :meth:`Tree.count_topologies` for more + detail on how this structure is used. 
+ """ + + def __init__(self): + self.topologies = collections.defaultdict(collections.Counter) + + def __getitem__(self, sample_set_indexes): + k = TopologyCounter._to_key(sample_set_indexes) + return self.topologies[k] + + def __setitem__(self, sample_set_indexes, counter): + k = TopologyCounter._to_key(sample_set_indexes) + self.topologies[k] = counter + + @staticmethod + def _to_key(sample_set_indexes): + if not isinstance(sample_set_indexes, collections.Iterable): + sample_set_indexes = (sample_set_indexes,) + return tuple(sorted(sample_set_indexes)) + + def __eq__(self, other): + return self.__class__ == other.__class__ and self.topologies == other.topologies + + @staticmethod + def merge(topology_counters): + """ + Union together independent topology counters into one. + """ + total = TopologyCounter() + for tc in topology_counters: + for k, v in tc.topologies.items(): + total.topologies[k] += v + + return total + + @staticmethod + def from_sample(sample_set_index): + """ + Generate the topologies covered by a single sample. This + is the single-leaf topology representing the single sample + set. + """ + rank_tree = RankTree(children=[], label=sample_set_index) + tc = TopologyCounter() + tc[sample_set_index][rank_tree.rank()] = 1 + return tc + + +class PartialTopologyCounter: + """ + Represents the possible combinations of children under a node in a tree + and the combinations of embedded topologies that are rooted at the node. + This allows an efficient way of calculating which unique embedded + topologies arise by only ever storing a given pairing of sibling topologies + once. + ``partials`` is a dictionary where a key is a tuple of sample set indexes, + and the value is a ``collections.Counter`` that counts combinations of + sibling topologies whose tips represent the sample sets in the key. + Each element of the counter is a homogeneous tuple where each element represents + a topology.
The topology is itself a tuple of the sample set indexes in that + topology and the rank. + """ + + def __init__(self): + self.partials = collections.defaultdict(collections.Counter) + + def add_sibling_topologies(self, topology_counter): + """ + Combine each topology in the given TopologyCounter with every existing + combination of topologies whose sample set indexes are disjoint from the + topology from the counter. This also includes adding the topologies from + the counter without joining them to any existing combinations. + """ + merged = collections.defaultdict(collections.Counter) + for sample_set_indexes, topologies in topology_counter.topologies.items(): + for rank, count in topologies.items(): + topology = ((sample_set_indexes, rank),) + # Cross with existing topology combinations + for sibling_sample_set_indexes, siblings in self.partials.items(): + if isdisjoint(sample_set_indexes, sibling_sample_set_indexes): + for sib_topologies, sib_count in siblings.items(): + merged_topologies = merge_tuple(sib_topologies, topology) + merged_sample_set_indexes = merge_tuple( + sibling_sample_set_indexes, sample_set_indexes + ) + merged[merged_sample_set_indexes][merged_topologies] += ( + count * sib_count + ) + # Propagate without combining + merged[sample_set_indexes][topology] += count + + for sample_set_indexes, counter in merged.items(): + self.partials[sample_set_indexes] += counter + + def join_all_combinations(self): + """ + For each pairing of child topologies, join them together into a new + tree and count the resulting topologies. 
+ """ + topology_counter = TopologyCounter() + for sample_set_indexes, sibling_topologies in self.partials.items(): + for topologies, count in sibling_topologies.items(): + # A node must have at least two children + if len(topologies) >= 2: + rank = PartialTopologyCounter.join_topologies(topologies) + topology_counter[sample_set_indexes][rank] += count + else: + # Pass on the single tree without adding a parent + for _, rank in topologies: + topology_counter[sample_set_indexes][rank] += count + + return topology_counter + + @staticmethod + def join_topologies(child_topologies): + children = [] + for sample_set_indexes, rank in child_topologies: + n = len(sample_set_indexes) + t = RankTree.unrank(rank, n, list(sample_set_indexes)) + children.append(t) + children.sort(key=RankTree.canonical_order) + return RankTree(children).rank() + + def all_trees(num_leaves): """ Generates all unique leaf-labelled trees with ``num_leaves`` @@ -205,12 +429,17 @@ def label_rank(self): return self._label_rank @staticmethod - def unrank(rank, num_leaves): + def unrank(rank, num_leaves, labels=None): + """ + Produce a ``RankTree`` of the given ``rank`` with ``num_leaves`` leaves, + labelled with ``labels``. Labels must be sorted, and if ``None`` default + to ``[0, num_leaves)``. + """ shape_rank, label_rank = rank if shape_rank < 0 or label_rank < 0: raise ValueError("Rank is out of bounds.") unlabelled = RankTree.shape_unrank(shape_rank, num_leaves) - return unlabelled.label_unrank(label_rank) + return unlabelled.label_unrank(label_rank, labels) @staticmethod def shape_unrank(shape_rank, n): @@ -230,7 +459,7 @@ def shape_unrank(shape_rank, n): def label_unrank(self, label_rank, labels=None): """ Generate a tree with the same shape, whose leaves are labelled - from `labels` with the labelling corresponding to `label_rank`. + from ``labels`` with the labelling corresponding to ``label_rank``. 
""" if labels is None: labels = list(range(self.num_leaves)) @@ -268,6 +497,9 @@ def from_tsk_tree_node(tree, u): if tree.is_leaf(u): return RankTree(children=[], label=u) + if tree.num_children(u) == 1: + raise ValueError("Cannot rank trees with unary nodes") + children = list( sorted( (RankTree.from_tsk_tree_node(tree, c) for c in tree.children(u)), @@ -279,11 +511,14 @@ def from_tsk_tree_node(tree, u): @staticmethod def from_tsk_tree(tree): if tree.num_roots != 1: - raise ValueError("Can't rank trees with multiple roots") + raise ValueError("Cannot rank trees with multiple roots") return RankTree.from_tsk_tree_node(tree, tree.root) def to_tsk_tree(self): + if set(self.labels) != set(range(self.num_leaves)): + raise ValueError("Labels set must be equivalent to [0, num_leaves)") + seq_length = 1 tables = tskit.TableCollection(seq_length) @@ -480,7 +715,7 @@ def is_symmetrical(self): # so we should compute a vector of those results up front instead of using # repeated calls to this function. # Put an lru_cache on for now as a quick replacement (cuts test time down by 80%) -@lru_cache() +@functools.lru_cache() def num_shapes(n): """ The cardinality of the set of unlabelled trees with n leaves, @@ -860,3 +1095,11 @@ def group_by(values, equal): def group_partition(part): return group_by(part, lambda x, y: x == y) + + +def merge_tuple(tup1, tup2): + return tuple(heapq.merge(tup1, tup2)) + + +def isdisjoint(iterable1, iterable2): + return set(iterable1).isdisjoint(iterable2) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 7d4314f5e1..6f6135e5b5 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -833,6 +833,57 @@ def unrank(rank, num_leaves): """ return combinatorics.RankTree.unrank(rank, num_leaves).to_tsk_tree() + def count_topologies(self, sample_sets=None): + """ + Calculates the distribution of embedded topologies for every combination + of the sample sets in ``sample_sets``. 
``sample_sets`` defaults to all + samples in the tree grouped by population. + + ``sample_sets`` need not include all samples but must be pairwise disjoint. + + The returned object is a :class:`tskit.TopologyCounter` that contains + counts of topologies per combination of sample sets. For example, + + >>> topology_counter = tree.count_topologies() + >>> rank, count = topology_counter[0, 1, 2].most_common(1)[0] + + produces the most common tree topology, with populations 0, 1 + and 2 as its tips, according to the genealogies of those + populations' samples in this tree. + + The counts for each topology in the :class:`tskit.TopologyCounter` + are absolute counts that we would get if we were to select all + combinations of samples from the relevant sample sets. + For sample sets :math:`[s_0, s_1, ..., s_n]`, the total number of + topologies for those sample sets is equal to + :math:`|s_0| * |s_1| * ... * |s_n|`, so the counts in the counter + ``topology_counter[0, 1, ..., n]`` should sum to + :math:`|s_0| * |s_1| * ... * |s_n|`. + + To convert the topology counts to probabilities, divide by the total + possible number of sample combinations from the sample sets in question:: + + >>> set_sizes = [len(sample_set) for sample_set in sample_sets] + >>> p = count / (set_sizes[0] * set_sizes[1] * set_sizes[2]) + + .. warning:: The interface for this method is preliminary and may be subject to + backwards incompatible changes in the near future. + + :param list sample_sets: A list of lists of Node IDs, specifying the + groups of nodes to compute the statistic with. + Defaults to all samples grouped by population. + :rtype: tskit.TopologyCounter + :raises ValueError: If nodes in ``sample_sets`` are invalid or are + internal samples. 
+ """ + if sample_sets is None: + sample_sets = [ + self.tree_sequence.samples(population=pop.id) + for pop in self.tree_sequence.populations() + ] + + return combinatorics.tree_count_topologies(self, sample_sets) + def get_branch_length(self, u): # Deprecated alias for branch_length return self.branch_length(u) @@ -5840,6 +5891,29 @@ def kc_distance(self, other, lambda_=0.0): """ return self._ll_tree_sequence.get_kc_distance(other._ll_tree_sequence, lambda_) + def count_topologies(self, sample_sets=None): + """ + Returns a generator that produces the same distribution of topologies as + :meth:`Tree.count_topologies` but sequentially for every tree in a tree + sequence. For use on a tree sequence this method is much faster than + computing the result independently per tree. + + .. warning:: The interface for this method is preliminary and may be subject to + backwards incompatible changes in the near future. + + :param list sample_sets: A list of lists of Node IDs, specifying the + groups of nodes to compute the statistic with. + :rtype: iter(:class:`tskit.TopologyCounter`) + :raises ValueError: If nodes in ``sample_sets`` are invalid or are + internal samples. + """ + if sample_sets is None: + sample_sets = [ + self.samples(population=pop.id) for pop in self.populations() + ] + + yield from combinatorics.treeseq_count_topologies(self, sample_sets) + ############################################ # # Deprecated APIs. These are either already unsupported, or will be unsupported in a