-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmonte_carlo_entropy.py
376 lines (309 loc) · 11.7 KB
/
monte_carlo_entropy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
"""
nwk, fasta, out_nwk, num_iter = sys.argv[1:5]
try: existing_tree = sys.argv[5]
except: existing_tree = False
"""
import dendropy
import copy
from tqdm import tqdm
import numpy as np
import itertools
import sys
from matplotlib import pyplot as plt
from utils.seq_utils import create_seqs_dict
# Raise the interpreter recursion limit (presumably so that deep trees can be
# traversed without a RecursionError -- confirm).
sys.setrecursionlimit(10000)
# Next unused node label: set in main() from assign_internal_node_labels()
# and incremented by create_node().
LABEL_COUNT = 0
# Alignment length (number of columns per sequence): set in main() and read
# by compute_cluster_entropy().
SEQ_LEN = 0
def compute_initial_tree_entropy(tree, counts):
    """
    Compute and cache the entropy of every cluster (node) in the tree.

    Intended to be run once, before any move, to obtain the initial entropy
    of each cluster.  Every node's value is stored under
    counts[label]["entropy"].

    Returns the negated sum of (cluster entropy * cluster size) over all nodes.
    """
    weighted_total = 0
    for node in tqdm(tree.nodes(), leave=False):
        entry = counts[node.label]
        node_entropy = compute_cluster_entropy(node, counts)
        weighted_total += node_entropy * entry["size"]
        # Cache the value so later passes can reuse it without recomputing.
        entry["entropy"] = node_entropy
    return -1 * weighted_total
def compute_tree_entropy(tree, counts):
    """
    Return the tree score from the cached per-cluster entropies.

    Sums counts[label]["entropy"] * counts[label]["size"] over every node of
    the tree and negates the result.  Nothing is recomputed here; the cached
    entropies must already be up to date.
    """
    weighted_total = 0
    for node in tree.nodes():
        entry = counts[node.label]
        weighted_total += entry["entropy"] * entry["size"]
    return -1 * weighted_total
def compute_tree_entropy_dividebyparent(tree, counts):
    """
    Variant tree score in which each cluster is weighted by its share of its
    parent.

    For every non-root node this accumulates
        entropy * size / parent_size  +  entropy
    from the cached values in `counts`, and returns the negated total.
    The root (tree.seed_node) is skipped because it has no parent.
    """
    # NOTE(review): both accumulation steps are taken to be per-node (inside
    # the loop) -- confirm against the intended formula.
    total = 0
    non_root = [node for node in tree.nodes() if not (node == tree.seed_node)]
    for node in non_root:
        own = counts[node.label]
        parent_size = counts[node.parent_node.label]["size"]
        total += (own["entropy"] * own["size"]) / parent_size
        total += own["entropy"]
    return -1 * total
def compute_cluster_entropy(cluster, counts):
    """
    Return the mean, over alignment columns, of sum(p * log2(p)) for a cluster.

    For every column, p is the fraction of the cluster's sequences carrying
    each of the four nucleotides (count columns 0-3; column 4, the gap/N
    count, is ignored).  Terms with p == 0 contribute 0.  The value is NOT
    negated here; the tree-level functions apply the sign when aggregating.

    Params:
    cluster: node whose label keys into `counts`
    counts: dict label -> {"counts": (SEQ_LEN x 5) array, "size": int, ...}

    Relies on the module-level SEQ_LEN; raises ZeroDivisionError while it is
    still 0.
    """
    entry = counts[cluster.label]
    # Nucleotide frequencies for every position at once (A, C, T, G only).
    freqs = entry["counts"][:SEQ_LEN, :4] / entry["size"]
    # Drop zero frequencies up front: their term is defined as 0, and this
    # avoids ever evaluating log2(0).
    observed = freqs[freqs != 0]
    # A single vectorised reduction replaces the 4 * SEQ_LEN Python-level
    # iterations; this function runs for every affected ancestor on every
    # proposed move.  Summation order differs from a sequential loop, so the
    # result may differ in the last floating-point bits.
    total = float(np.sum(observed * np.log2(observed)))
    return total / SEQ_LEN
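# This does not produce the desired results. Likely reasons: np.log2(..., where=...)
# without an `out` array leaves the masked (count == 0) entries uninitialized, the
# sum also includes the gap/N column, and the sign is flipped relative to
# compute_cluster_entropy.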
# def fast_cluster_entropy(cluster, counts):
# size = counts[cluster.label]["size"]
# count = counts[cluster.label]["counts"]/size
# entropy = count * np.log2(count, where=count != 0)
# return -1 * np.sum(entropy) / SEQ_LEN
def create_counts_matrices(tree, seqs_m, seqs_index, SEQ_LEN):
    """
    Build the per-cluster nucleotide count table for every node of the tree.

    Params:
    tree: tree whose nodes all carry a unique label
    seqs_m: one-hot tensor, one (SEQ_LEN x 5) matrix per sequence
    seqs_index: dict mapping sequence id -> row of seqs_m
    SEQ_LEN: number of alignment columns

    Returns a dict mapping node label ->
        {"counts": (SEQ_LEN x 5) array summed over the node's leaves,
         "size": number of leaves under the node,
         "entropy": 0 (placeholder, filled in later)}

    Raises KeyError if a leaf's taxon label (spaces replaced by underscores)
    is not present in seqs_index.
    """
    counts = {}
    # Post-order guarantees every child is finished before its parent, so an
    # internal node is just the sum of its children instead of a fresh walk
    # over all of its leaves.
    for node in tree.postorder_node_iter():
        # Always start from a fresh array: entries are later modified in
        # place and must never alias seqs_m or each other.
        totals = np.zeros(shape=(SEQ_LEN, 5))
        if node.is_leaf():
            row = seqs_index[node.taxon.label.replace(" ", "_")]
            totals += seqs_m[row]
            size = 1
        else:
            size = 0
            for child in node.child_node_iter():
                child_entry = counts[child.label]
                totals += child_entry["counts"]
                size += child_entry["size"]
        counts[node.label] = {"counts": totals, "size": size, "entropy": 0}
    return counts
def update_counts(counts, i, i_parent, j_parent, lca):
    """
    After a move, update
    1. the nucleotide counts
    2. the size
    3. the entropy
    of the ancestor clusters affected by relocating node i.

    Two upward walks are made; each stops when it runs out of ancestors or
    reaches `lca` (which itself is not modified):
      * starting at `i_parent`, the moved subtree's counts and size are
        subtracted;
      * starting at `j_parent`, they are added.

    Params:
    counts: dict label -> {"counts", "size", "entropy"}; modified in place
    i: node that was moved
    i_parent: first remaining ancestor of i's old position (main() passes
        the former grandparent of i, because the old parent is removed)
    j_parent: parent of the destination node, captured before the move
    lca: node at which both walks stop

    Returns the same `counts` dict.
    """
    moved = counts[i.label]

    # Old location: ancestors lose the moved subtree.
    node = i_parent
    while node and node != lca:
        entry = counts[node.label]
        entry["counts"] -= moved["counts"]
        # Whole-array check instead of a Python loop over every cell; a
        # negative count means the bookkeeping has gone wrong.
        assert (entry["counts"] >= 0).all()
        entry["size"] -= moved["size"]
        entry["entropy"] = compute_cluster_entropy(node, counts)
        node = node.parent_node

    # New location: ancestors gain the moved subtree.
    node = j_parent
    while node and node != lca:
        entry = counts[node.label]
        entry["counts"] += moved["counts"]
        assert (entry["counts"] >= 0).all(), "Error in update counts adding"
        entry["size"] += moved["size"]
        entry["entropy"] = compute_cluster_entropy(node, counts)
        node = node.parent_node

    return counts
def update_counts_after_reroot(counts, i, j_parent, new_root):
    """
    Count/size/entropy update used when the move re-rooted the tree.

    Walks upward from `j_parent`, adding the moved node's counts and size to
    each ancestor and refreshing its cached entropy.  `new_root` is the last
    node updated; the walk also ends if it runs out of ancestors first.

    Params:
    counts: dict label -> {"counts", "size", "entropy"}; modified in place
    i: node that was moved
    j_parent: parent of the destination node, captured before the move
    new_root: root of the tree after re-rooting

    Returns the same `counts` dict.
    """
    moved = counts[i.label]
    node = j_parent
    while node:
        entry = counts[node.label]
        entry["counts"] += moved["counts"]
        entry["size"] += moved["size"]
        entry["entropy"] = compute_cluster_entropy(node, counts)
        if node == new_root:
            return counts
        node = node.parent_node
    return counts
def move(tree, node_1, node_2, counts):
    """
    Detach node_1 and re-attach it as the sibling of node_2.

    A new internal node is created with children [node_1, node_2] and hung
    where node_2 used to be.  node_1's old parent is then removed: if it was
    the root, the tree is re-rooted at node_1's old sibling; otherwise the old
    sibling is re-attached to node_1's old grandparent.

    Params:
    tree: tree to modify in place
    node_1: node being moved
    node_2: destination node
    counts: dict label -> count entry; gains an entry for the new node

    Returns (counts, reroot) where reroot is True when the tree was re-rooted.

    NOTE(review): find_valid_move() does not exclude node_2 being an ancestor
    (grandparent or higher) of node_1.  In that case counts[node_2] still
    includes node_1's sequences when the new node's entry is built below --
    confirm whether such moves are intended.
    """
    # Capture the neighbourhood before any pruning changes it.
    old_parent = node_1.parent_node
    old_sibling = node_1.sibling_nodes()[0]
    dest_parent = node_2.parent_node

    tree.prune_subtree(node_1, suppress_unifurcations=False)
    tree.prune_subtree(node_2, suppress_unifurcations=False)

    joint = create_node()
    insert_internal_node(joint, parent=dest_parent, children=[node_1, node_2])

    reroot = (old_parent == tree.seed_node)
    if reroot:
        # The old parent was the root: the old sibling becomes the new root.
        tree.reroot_at_node(old_sibling, suppress_unifurcations=False)
        tree.prune_subtree(old_parent, suppress_unifurcations=False)
    else:
        # Splice the now-redundant old parent out of the tree.
        old_grandparent = old_parent.parent_node
        tree.prune_subtree(old_sibling, suppress_unifurcations=False)
        tree.prune_subtree(old_parent, suppress_unifurcations=False)
        old_grandparent.add_child(old_sibling)

    counts = add_counts_entry(joint, node_1, node_2, counts)
    return counts, reroot
def find_valid_move(tree):
    """
    Randomly pick a pair of nodes (i, j) such that i can be moved next to j.

    A pair is rejected when j is i's parent, when i and j are siblings, or
    when j lies inside i's subtree (which includes j == i).  Up to
    10 * (number of nodes) re-draws are made after the first draw.

    Returns (i, j, found).  `found` reflects whether the returned pair is
    actually valid, so a valid pair drawn on the very last attempt is still
    reported as found.  If the tree has no non-root nodes, returns
    (None, None, False) instead of raising.
    """
    non_root_nodes = tree.nodes(lambda x: x != tree.seed_node)
    if not non_root_nodes:
        # Nothing to draw from; np.random.choice would raise on an empty list.
        return None, None, False

    def _is_invalid(src, dst):
        return (
            dst == src.parent_node
            or dst.parent_node == src.parent_node
            or dst in src.preorder_iter()
        )

    retries_left = len(tree.nodes()) * 10
    i = np.random.choice(non_root_nodes)
    j = np.random.choice(non_root_nodes)
    while _is_invalid(i, j) and retries_left:
        i = np.random.choice(non_root_nodes)
        j = np.random.choice(non_root_nodes)
        retries_left -= 1

    # Decide from the pair itself, not from the retry counter.
    return i, j, not _is_invalid(i, j)
def add_counts_entry(new_internal, node_1, node_2, counts):
    """
    Register the freshly created internal node in `counts`.

    Its nucleotide counts and size are the sums of the two children's
    entries; its entropy is computed from those sums.

    Returns the same `counts` dict with the new entry added.
    """
    first = counts[node_1.label]
    second = counts[node_2.label]
    entry = {
        "counts": first["counts"] + second["counts"],
        "size": first["size"] + second["size"],
    }
    # The entry must be in the table before its entropy can be computed,
    # because compute_cluster_entropy looks it up by label.
    counts[new_internal.label] = entry
    entry["entropy"] = compute_cluster_entropy(new_internal, counts)
    return counts
def create_node():
    """
    Create a new dendropy node labelled with the next unused label.

    The label is the string form of the module-level LABEL_COUNT, which is
    incremented afterwards so that the next call gets a different label.
    """
    global LABEL_COUNT
    fresh = dendropy.Node()
    fresh._set_label(str(LABEL_COUNT))
    LABEL_COUNT += 1
    return fresh
def insert_internal_node(node, parent, children):
    """
    Give `node` the supplied `children`, then attach `node` under `parent`.

    Params:
    node: the internal node being inserted
    parent: node that becomes `node`'s parent
    children: list of nodes that become `node`'s children
    """
    # Children first, then hook the assembled subtree into the tree.
    node.set_child_nodes(children)
    parent.add_child(node)
def create_seqs_matrix(seqs, SEQ_LEN):
    """
    ONE-HOT ENCODING

    Params:
    seqs: dict mapping id -> sequence string
    SEQ_LEN: number of alignment columns (M)

    Returns (seqs_m, seqs_index):
    seqs_m: N x M x 5 tensor where each row is an M x 5 one-hot encoding of an
        M-length sequence; columns are A, C, T, G and a shared gap/unknown
        column for "-" and "N"
    seqs_index: dict mapping id -> row of seqs_m

    Raises ValueError for any character other than A, C, T, G, "-" or "N".
    An explicit exception is used (rather than assert) so the check still
    runs when Python is started with -O.
    """
    column_of = {"A": 0, "C": 1, "T": 2, "G": 3, "-": 4, "N": 4}
    seqs_m = np.zeros(shape=(len(seqs), SEQ_LEN, 5))
    seqs_index = {}
    for row, seq_id in enumerate(seqs):
        seqs_index[seq_id] = row
        for pos, nucl in enumerate(seqs[seq_id]):
            column = column_of.get(nucl)
            if column is None:
                raise ValueError(
                    f"nucl: {nucl} (sequence {seq_id!r}, position {pos}); "
                    "expected one of A, C, T, G, -, N"
                )
            seqs_m[row, pos, column] = 1
    return seqs_m, seqs_index
def get_lca(tree, i, j):
    """
    Return the lowest common ancestor, in `tree`, of nodes i and j.

    The ancestor is looked up by taxon label: the labels of all leaves under
    i and under j are collected and passed to tree.mrca().
    """
    leaf_taxa = {
        leaf.taxon.label
        for leaf in itertools.chain(i.leaf_iter(), j.leaf_iter())
    }
    return tree.mrca(taxon_labels=leaf_taxa)
def assign_internal_node_labels(tree):
    """
    Label every node of the tree with consecutive integers (as strings).

    Despite the name, all nodes returned by tree.nodes() are labelled, not
    only internal ones.  Labels start at "0" and follow the order of
    tree.nodes().

    Returns the number of labels handed out, i.e. the next unused label.
    """
    assigned = 0
    for assigned, node in enumerate(tree.nodes(), start=1):
        # enumerate counts from 1 so that `assigned` ends as the total.
        node._set_label(str(assigned - 1))
    return assigned
def finish_node(node):
    """
    Split a combined "<taxon><sep><node label>" leaf label back into its parts.

    Used as the finish_node_fn when re-reading a tree written by tree_copy(),
    where a leaf's taxon label and node label come back joined by a space or
    an underscore.  Nodes without a taxon are returned untouched.

    Splitting rules, in order:
    1. exactly one space            -> split on the space
    2. otherwise exactly one "_"    -> split on the underscore
    3. otherwise                    -> split at the last space/underscore,
       so taxon names that themselves contain the separator (e.g. "seq_1_3")
       no longer crash; the node label is taken to be the final component.

    Raises ValueError if the label contains neither separator.
    Returns the node, with node.taxon.label and node.label updated.
    """
    if node.taxon is None:
        return node

    combined = node.taxon.label
    parts = combined.split(" ")
    if len(parts) != 2:
        parts = combined.split("_")
    if len(parts) != 2:
        cut = max(combined.rfind(" "), combined.rfind("_"))
        if cut < 0:
            raise ValueError(
                f"cannot split leaf label {combined!r} into taxon and node label"
            )
        parts = [combined[:cut], combined[cut + 1:]]

    taxon, label = parts
    node.taxon.label = taxon
    node.label = label
    return node
def read_args():
    """
    Read the command-line arguments.

    Usage: <nwk> <fasta> <out_nwk> <num_iter> [existing_tree]

    Returns a flat list [tree_path, fasta, out_nwk, num_iter] that main()
    unpacks into four names.  When the optional fifth argument is given it is
    used as the tree path in place of <nwk> (presumably a previously
    produced tree to continue from -- confirm).

    Exits with a usage message if fewer than four arguments are supplied.
    """
    args = sys.argv[1:]
    if len(args) < 4:
        raise SystemExit(
            "usage: monte_carlo_entropy.py <nwk> <fasta> <out_nwk> <num_iter> [existing_tree]"
        )
    nwk, fasta, out_nwk, num_iter = args[:4]
    if len(args) > 4:
        # The existing tree takes the place of the first argument.
        nwk = args[4]
    return [nwk, fasta, out_nwk, num_iter]
def tree_copy(tree):
    """
    Return a copy of `tree` made by a Newick round trip.

    The tree is serialised with both leaf and internal node labels (edge
    lengths and the rooting token are omitted) and parsed back as a rooted
    tree; finish_node() is applied to each parsed node to separate the leaf
    taxon label from the node label again.
    """
    write_options = {
        "schema": "newick",
        "suppress_leaf_node_labels": False,
        "suppress_internal_node_labels": False,
        "suppress_edge_lengths": True,
        "suppress_rooting": True,
    }
    read_options = {
        "schema": "newick",
        "suppress_edge_lengths": True,
        "preserve_underscores": True,
        "finish_node_fn": finish_node,
        "rooting": 'force-rooted',
    }
    newick = tree.as_string(**write_options)
    return dendropy.Tree.get(data=newick, **read_options)
def read_tree(path):
    """
    Load a Newick tree from `path` as a rooted dendropy Tree.

    Edge lengths are discarded and underscores in labels are preserved.
    """
    options = {
        "schema": "newick",
        "suppress_edge_lengths": True,
        "preserve_underscores": True,
        "rooting": 'force-rooted',
    }
    return dendropy.Tree.get(path=path, **options)
def main():
    """
    Run the entropy-driven tree search from the command line.

    Steps:
    1. read the arguments, the starting tree and the sequences;
    2. label every node and build the per-cluster count table;
    3. compute the starting entropy;
    4. for num_iter iterations: copy the current tree and counts, pick a
       random move, apply it to the copy and update the affected counts; keep
       the copy only if the tree entropy became strictly smaller;
    5. write the final tree to out_nwk (it is also written after every
       accepted move).
    """
    global LABEL_COUNT
    global SEQ_LEN

    nwk, fasta, out_nwk, num_iter = read_args()

    tree_0 = read_tree(nwk)
    LABEL_COUNT = assign_internal_node_labels(tree_0)

    seqs_d = create_seqs_dict(fasta)
    # All sequences are taken to have the length of the first one.
    SEQ_LEN = len(list(seqs_d.values())[0])
    seqs_m, seqs_index = create_seqs_matrix(seqs_d, SEQ_LEN)

    print('Creating counts matrices...')
    counts = create_counts_matrices(tree_0, seqs_m, seqs_index, SEQ_LEN)

    print('Computing starting entropy...')
    starting_entropy = compute_initial_tree_entropy(tree_0, counts)
    current_entropy = starting_entropy
    total_moves = 0

    # Shared by the per-accept checkpoint and the final save.
    write_options = dict(
        path=out_nwk,
        schema="newick",
        suppress_internal_node_labels=True,
        suppress_edge_lengths=True,
        suppress_rooting=True,
    )

    print('Running moves experiment...')
    for _ in tqdm(range(int(num_iter))):
        # TODO: Retain copies between rejected moves
        tree = tree_copy(tree_0)
        new_counts = copy.deepcopy(counts)

        # The move is chosen on the current tree and replayed on the copy,
        # matching nodes between the two by label.
        i, j, found = find_valid_move(tree_0)
        if not found:
            print("no suitable moves found.")
            break
        node_i = tree.find_node_with_label(i.label)
        node_j = tree.find_node_with_label(j.label)
        j_parent = node_j.parent_node

        new_counts, rerooted = move(tree, node_i, node_j, new_counts)
        if rerooted:
            new_counts = update_counts_after_reroot(
                new_counts, node_i, j_parent, tree.seed_node
            )
        else:
            lca = get_lca(tree, i, j)
            i_grandparent = i.parent_node.parent_node
            new_counts = update_counts(
                new_counts, node_i, i_grandparent, j_parent, lca
            )

        entropy = compute_tree_entropy(tree, new_counts)
        if not (entropy < current_entropy):
            # Rejected: the copy is discarded and tree_0 stays as it was.
            continue

        tree_0 = tree
        counts = new_counts
        current_entropy = entropy
        total_moves += 1
        tree_0.write(**write_options)

    print(f" Entropy: {starting_entropy:.2f} --> {current_entropy:.2f} after {total_moves}/{num_iter} accepted moves.")

    # Save final tree
    tree_0.write(**write_options)
# Script entry point: run the search only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
# import io
# import pstats
# import cProfile
# with cProfile.Profile() as pr:
# main()
# stream = io.StringIO()
# stats = pstats.Stats(pr, stream=stream)
# stats.sort_stats("cumtime")
# stats.print_stats()
# with open('mc_stats_counter.txt', 'r') as f:
# mcsc = f.read()
# with open(f'mc_stats_{mcsc}.txt', 'w') as f:
# f.write(stream.getvalue())
# with open(f'mc_stats_counter.txt', 'w') as f:
# f.write(str(int(mcsc) + 1))