Skip to content

Commit d0d4242

Browse files
Implement algorithm for calculating species tree topology distributions.
1 parent 1a1f315 commit d0d4242

File tree

2 files changed

+314
-2
lines changed

2 files changed

+314
-2
lines changed

python/tests/test_combinatorics.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,13 @@
2323
"""
2424
Test cases for combinatorial algorithms.
2525
"""
26+
import io
2627
import itertools
2728
import unittest
29+
from collections import Counter
30+
from collections import defaultdict
31+
32+
import msprime
2833

2934
import tskit
3035
import tskit.combinatorics as comb
@@ -413,3 +418,166 @@ def test_is_symmetrical(self):
413418
self.assertFalse(three_leaf_asym.is_symmetrical())
414419
six_leaf_sym = RankTree(children=[three_leaf_asym, three_leaf_asym])
415420
self.assertTrue(six_leaf_sym.is_symmetrical())
421+
422+
423+
class TestCountTopologies(unittest.TestCase):
    """
    Tests for ``comb.count_topologies`` and
    ``comb.count_topologies_incremental``.

    Each test builds a tree sequence and hands it to ``verify_topologies``,
    which cross-checks the two implementations against each other, against a
    brute-force subsampling of the tree, and (optionally) against hard-coded
    expected counts.
    """

    def verify_topologies(self, ts, expected=None):
        """
        Check the topology counts of every tree in ``ts``.

        For every non-empty subset of the populations in ``ts`` and every
        tree, this asserts that:

        - the per-tree count equals the brute-force count from
          ``subsample_topologies`` (only when the tree has exactly one root),
        - the per-tree count equals ``expected[i][frozenset(pops)]`` when
          ``expected`` is given (one mapping per tree, indexed by tree
          position),
        - the per-tree count equals the incremental count.
        """
        populations = [pop.id for pop in ts.populations()]
        topologies = [comb.count_topologies(t) for t in ts.trees()]
        inc_topologies = list(comb.count_topologies_incremental(ts))
        # Every non-empty population subset, smallest subsets first.
        pop_subsets = [
            frozenset(pops)
            for num_pops in range(1, ts.num_populations + 1)
            for pops in itertools.combinations(populations, num_pops)
        ]

        for i, t in enumerate(ts.trees()):
            # The brute-force check is only made for single-root trees, so
            # only build the single-tree tree sequence when it will be used,
            # and build it once per tree rather than once per subset size.
            single_root = len(t.roots) == 1
            just_t = ts.keep_intervals([t.interval]) if single_root else None
            for pops in pop_subsets:
                actual_topologies = topologies[i][pops]
                actual_inc_topologies = inc_topologies[i][pops]
                if single_root:
                    subsampled_topologies = self.subsample_topologies(just_t, pops)
                    self.assertEqual(actual_topologies, subsampled_topologies)
                if expected is not None:
                    self.assertEqual(actual_topologies, expected[i][pops])
                self.assertEqual(actual_topologies, actual_inc_topologies)

    def subsample_topologies(self, ts, populations):
        """
        Brute-force topology count for the given populations.

        Takes every way of choosing one sample from each population,
        simplifies ``ts`` down to those samples and tallies the rank of each
        resulting single-root tree. Returns a ``Counter`` mapping rank to the
        number of sample choices that produced it.
        """
        samples_per_pop = [ts.samples(population=p) for p in populations]
        topologies = Counter()
        for subsample in itertools.product(*samples_per_pop):
            for pop_tree in ts.simplify(samples=subsample).trees():
                # regions before and after keep interval have all samples as roots
                # so don't count those
                # The single tree of interest should have one root
                if len(pop_tree.roots) == 1:
                    topologies[pop_tree.rank()] += 1
        return topologies

    def test_single_population(self):
        """One population of ``n`` samples: ``n`` single-leaf topologies per tree."""
        n = 10
        # Fixed seed so that a failure can be reproduced.
        ts = msprime.simulate(n, recombination_rate=10, random_seed=1)
        expected = defaultdict(Counter)
        expected[frozenset([0])] = Counter({(0, 0): n})
        self.verify_topologies(ts, [expected] * ts.num_trees)

    def test_three_populations(self):
        """Hand-built single tree with two samples in each of three populations."""
        nodes = io.StringIO(
            """\
id is_sample time population individual metadata
0 1 0.000000 0 -1
1 1 0.000000 1 -1
2 1 0.000000 1 -1
3 1 0.000000 2 -1
4 1 0.000000 2 -1
5 1 0.000000 0 -1
6 0 1.000000 0 -1
7 0 2.000000 0 -1
8 0 2.000000 0 -1
9 0 3.000000 0 -1
10 0 4.000000 0 -1
"""
        )
        edges = io.StringIO(
            """\
left right parent child
0.000000 1.000000 6 4
0.000000 1.000000 6 5
0.000000 1.000000 7 1
0.000000 1.000000 7 2
0.000000 1.000000 8 3
0.000000 1.000000 8 6
0.000000 1.000000 9 7
0.000000 1.000000 9 8
0.000000 1.000000 10 0
0.000000 1.000000 10 9
"""
        )
        ts = tskit.load_text(
            nodes, edges, sequence_length=1, strict=False, base64_metadata=False
        )

        expected = defaultdict(Counter)
        expected[frozenset([0])] = Counter({(0, 0): 2})
        expected[frozenset([1])] = Counter({(0, 0): 2})
        expected[frozenset([2])] = Counter({(0, 0): 2})
        expected[frozenset([0, 1])] = Counter({(0, 0): 4})
        expected[frozenset([0, 2])] = Counter({(0, 0): 4})
        expected[frozenset([1, 2])] = Counter({(0, 0): 4})
        expected[frozenset([0, 1, 2])] = Counter({(1, 0): 4, (1, 1): 4})
        self.verify_topologies(ts, [expected])

    def test_multiple_roots(self):
        """Two unconnected samples: only the single-population counts exist."""
        tables = tskit.TableCollection(sequence_length=1.0)
        tables.populations.add_row()
        tables.populations.add_row()
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=0)
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=1)

        # Not samples so they are ignored
        tables.nodes.add_row(time=1)
        tables.nodes.add_row(time=1, population=1)

        expected = defaultdict(Counter)
        expected[frozenset([0])] = Counter({(0, 0): 1})
        expected[frozenset([1])] = Counter({(0, 0): 1})
        self.verify_topologies(tables.tree_sequence(), [expected])

    def test_no_full_topology(self):
        """Two joined samples plus one isolated sample from a third population."""
        tables = tskit.TableCollection(sequence_length=1.0)
        tables.populations.add_row()
        tables.populations.add_row()
        tables.populations.add_row()
        child1 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=0)
        child2 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=1)
        parent = tables.nodes.add_row(time=1)
        tables.edges.add_row(left=0, right=1, parent=parent, child=child1)
        tables.edges.add_row(left=0, right=1, parent=parent, child=child2)

        # Left as root so there is no topology with all three populations
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=2)

        expected = defaultdict(Counter)
        for pop_combo in [[0], [1], [2], [0, 1]]:
            expected[frozenset(pop_combo)] = Counter({(0, 0): 1})
        self.verify_topologies(tables.tree_sequence(), [expected])

    def test_polytomies(self):
        """A single parent with three children, one sample per population."""
        tables = tskit.TableCollection(sequence_length=1.0)
        tables.populations.add_row()
        tables.populations.add_row()
        tables.populations.add_row()
        c1 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=0)
        c2 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=1)
        c3 = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, population=2)
        p = tables.nodes.add_row(time=1)
        tables.edges.add_row(left=0, right=1, parent=p, child=c1)
        tables.edges.add_row(left=0, right=1, parent=p, child=c2)
        tables.edges.add_row(left=0, right=1, parent=p, child=c3)

        expected = defaultdict(Counter)
        for pop_combos in [[0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2]]:
            expected[frozenset(pop_combos)] = Counter({(0, 0): 1})
        self.verify_topologies(tables.tree_sequence(), [expected])

    def test_msprime_migrations(self):
        """Simulated data with 2, 3 and 4 populations of five samples each."""
        for num_populations in range(2, 5):
            samples = [5] * num_populations
            # Distinct fixed seed per configuration so runs are reproducible.
            ts = self.simulate_multiple_populations(
                samples, random_seed=num_populations
            )
            self.verify_topologies(ts)

    def simulate_multiple_populations(self, sample_sizes, random_seed=None):
        """
        Simulate a tree sequence with one population per entry of
        ``sample_sizes``.

        The migration matrix is zero except for the entries ``[i][i + 1]``,
        which are set to ``0.2 / (2 * (d - 1))`` where ``d`` is the number of
        populations; at least two populations are therefore required.
        ``random_seed`` is passed through to ``msprime.simulate``; the default
        ``None`` leaves the simulation unseeded.
        """
        d = len(sample_sizes)
        M = 0.2
        m = M / (2 * (d - 1))

        migration_matrix = [
            [m if k == i + 1 else 0 for k in range(d)] for i in range(d)
        ]

        pop_configurations = [
            msprime.PopulationConfiguration(sample_size=size) for size in sample_sizes
        ]
        return msprime.simulate(
            population_configurations=pop_configurations,
            migration_matrix=migration_matrix,
            recombination_rate=0.05,
            random_seed=random_seed,
        )

python/tskit/combinatorics.py

Lines changed: 146 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,155 @@
2626
"""
2727
import heapq
2828
import itertools
29+
from collections import Counter
30+
from collections import defaultdict
2931
from functools import lru_cache
3032

33+
import numpy as np
34+
3135
import tskit
3236

3337

38+
def count_topologies_incremental(ts, key=None):
    """
    Yield the topology counts for each tree in ``ts``, updating state
    incrementally from the edge differences between adjacent trees.

    :param ts: The tree sequence to process.
    :param key: Callable mapping a sample node ID to its label. Defaults to
        the population of the node.
    :return: A generator yielding one ``defaultdict(Counter)`` per tree,
        built by summing the ``topologies`` mapping of the ``TopologyTree``
        held for each root of that tree.
    """

    def node_population(u):
        return ts.node(u).population

    label_of = node_population if key is None else key

    num_nodes = ts.num_nodes
    # topology_tree[u] is the TopologyTree for the subtree below u, or None
    # when nothing below u carries a topology.
    topology_tree = [None] * num_nodes
    # Parent of each node under the edges processed so far; -1 means none.
    parent = np.full(num_nodes, -1)

    def update_state(tree, u):
        # Rebuild the topology of u from its children in ``tree`` and then
        # repeat for each ancestor found by following ``parent``.
        # NOTE(review): a node with no contributing children is reset to None
        # even if it is a sample that previously had children — confirm that
        # samples at internal nodes are not expected here.
        pending = [u]
        while pending:
            v = pending.pop()
            child_topologies = []
            c = tree.left_child(v)
            while c != tskit.NULL:
                below = topology_tree[c]
                if below is not None:
                    child_topologies.append(below)
                c = tree.right_sib(c)
            if child_topologies:
                topology_tree[v] = TopologyTree(child_topologies)
            else:
                topology_tree[v] = None
            above = parent[v]
            if above != -1:
                pending.append(above)

    def grouped_by_parent(edges):
        return itertools.groupby(edges, key=lambda e: e.parent)

    for sample in ts.samples():
        topology_tree[sample] = TopologyTree(children=[], label=label_of(sample))

    for tree, (_, edges_out, edges_in) in zip(ts.trees(), ts.edge_diffs()):
        # Avoid recomputing anything for the parent until all child edges
        # for that parent are inserted/removed
        for p, sibling_edges in grouped_by_parent(edges_out):
            for e in sibling_edges:
                parent[e.child] = -1
            update_state(tree, p)
        for p, sibling_edges in grouped_by_parent(edges_in):
            for e in sibling_edges:
                parent[e.child] = p
            update_state(tree, p)

        all_topologies = defaultdict(Counter)
        for root in tree.roots:
            root_topology = topology_tree[root]
            # None if there are no samples under the root
            if root_topology is None:
                continue
            for labels, counter in root_topology.topologies.items():
                all_topologies[labels] += counter
        yield all_topologies
88+
89+
90+
def count_topologies(tree, key=None):
    """
    Count the topologies embedded in a single tree.

    :param tree: The tree to process.
    :param key: Callable mapping a leaf node ID to its label. Defaults to
        the population of the node.
    :return: A ``defaultdict(Counter)`` obtained by summing, over every root
        of ``tree``, the ``topologies`` mapping of the ``TopologyTree`` built
        from that root.
    """

    def population_of(u):
        return tree.tree_sequence.node(u).population

    label_for = population_of if key is None else key

    totals = defaultdict(Counter)
    for root in tree.roots:
        per_root = TopologyTree.from_tsk_tree(tree, root, label_for).topologies
        for labels, counter in per_root.items():
            totals[labels] += counter
    return totals
104+
105+
106+
class TopologyTree:
    """
    A node of a tree annotated with the topologies embedded beneath it.

    ``topologies`` is computed on construction. It maps a ``frozenset`` of
    leaf labels to a ``Counter`` whose keys are tree ranks and whose values
    are the number of times that rank occurs for that set of labels.
    """

    def __init__(self, children, label=None):
        """
        :param children: List of child ``TopologyTree`` instances; an empty
            list makes this node a leaf.
        :param label: Label of the node. Must not be ``None`` for a leaf.
        """
        self.children = children
        if len(children) == 0:
            assert label is not None
        self.label = label

        self.topologies = self._topologies()

    def _topologies(self):
        """
        Build the label-set -> ``Counter`` of ranks mapping for this node
        from the mappings already computed for its children.
        """
        if self.is_leaf():
            rank_tree = RankTree(children=[], label=self.label)
            labels = frozenset([self.label])
            return {labels: Counter({rank_tree.rank(): 1})}

        # Create partial trees (incomplete sets of child nodes)
        partials = defaultdict(Counter)
        for child in self.children:
            merged = TopologyTree.merge_child_topologies(partials, child)
            for labels, counter in merged.items():
                partials[labels] += counter

        # Join sets of child nodes together and add to existing topologies
        topologies = defaultdict(Counter)
        for root_labels, sibling_topologies in partials.items():
            for k, count in sibling_topologies.items():
                # A node must have at least two children
                if len(k) >= 2:
                    children = []
                    for labels, rank in k:
                        labels_list = sorted(labels)
                        n = len(labels_list)
                        children.append(RankTree.unrank(rank, n, labels_list))
                    children.sort(key=RankTree.canonical_order)
                    tree = RankTree(children)
                    topologies[root_labels][tree.rank()] += count
                else:
                    # Pass on the single tree without adding a parent
                    for _, rank in k:
                        topologies[root_labels][rank] += count

        return topologies

    @staticmethod
    def merge_child_topologies(partials, child):
        """
        Combine every topology of ``child`` with the partial sibling sets
        collected so far.

        :param partials: Mapping from a ``frozenset`` of labels to a
            ``Counter`` keyed by ``frozenset`` of ``(labels, rank)`` pairs.
            It is read but not modified.
        :param child: Object whose ``topologies`` attribute maps a
            ``frozenset`` of labels to a ``Counter`` of ranks.
        :return: A new ``defaultdict(Counter)`` in the same shape as
            ``partials``. It holds each child topology on its own, plus its
            union with every existing partial whose labels are disjoint from
            it, with the counts multiplied.
        """
        merged = defaultdict(Counter)
        for labels, topologies in child.topologies.items():
            for rank, count in topologies.items():
                topology = frozenset([(labels, rank)])
                for sibling_labels, children in partials.items():
                    # A label may appear only once in a combined topology.
                    if labels.isdisjoint(sibling_labels):
                        for sib_topologies, sib_count in children.items():
                            merged_topologies = sib_topologies.union(topology)
                            merged_labels = sibling_labels.union(labels)
                            merged[merged_labels][merged_topologies] += (
                                count * sib_count
                            )
                # The child topology on its own is also a partial.
                merged[labels][topology] += count
        return merged

    def is_leaf(self):
        """Return True if this node has no children."""
        return len(self.children) == 0

    @staticmethod
    def from_tsk_tree(tree, u, key=None):
        """
        Recursively build the ``TopologyTree`` for the subtree of ``tree``
        rooted at node ``u``.

        :param tree: The tree to convert.
        :param u: ID of the node to start from.
        :param key: Callable mapping a leaf node ID to its label. Defaults to
            the population of the node.
        """
        if key is None:
            # Same default labelling as count_topologies; without it the
            # declared default of None would be called and raise TypeError.
            def key(v):
                return tree.tree_sequence.node(v).population

        if tree.is_leaf(u):
            return TopologyTree([], label=key(u))

        children = [TopologyTree.from_tsk_tree(tree, c, key) for c in tree.children(u)]
        return TopologyTree(children=children)
176+
177+
34178
def all_trees(num_leaves):
35179
"""
36180
Generates all unique leaf-labelled trees with ``num_leaves``
@@ -210,12 +354,12 @@ def label_rank(self):
210354
return self._label_rank
211355

212356
@staticmethod
def unrank(rank, num_leaves, labels=None):
    """
    Build the tree identified by ``rank`` among trees with ``num_leaves``
    leaves.

    :param rank: Pair ``(shape_rank, label_rank)``.
    :param num_leaves: Number of leaves, passed to ``shape_unrank``.
    :param labels: Passed through unchanged to ``label_unrank``; presumably
        the leaf labels to apply — confirm against ``label_unrank``.
    :raises ValueError: If either component of ``rank`` is negative.
    """
    shape_rank, label_rank = rank
    if any(component < 0 for component in (shape_rank, label_rank)):
        raise ValueError("Rank is out of bounds.")
    return RankTree.shape_unrank(shape_rank, num_leaves).label_unrank(
        label_rank, labels
    )
219363

220364
@staticmethod
221365
def shape_unrank(shape_rank, n):

0 commit comments

Comments
 (0)