-
Notifications
You must be signed in to change notification settings - Fork 2
/
INoDS_convenience_functions.py
553 lines (440 loc) · 22.5 KB
/
INoDS_convenience_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
import csv
import networkx as nx
import numpy as np
import scipy.stats as ss
from random import shuffle
import matplotlib.pyplot as plt
import pandas as pd
from itertools import combinations
###########################################################################
def can_nodes_recover(infection_type):
    r"""Return True when the infection model allows node recovery.

    INoDS supports SI, SIR and SIS models; recovery times are imputed
    only for SIR and SIS, so those return True and SI returns False.

    Raises ValueError for any other model string.
    """
    model = infection_type.upper()
    if model not in ("SI", "SIR", "SIS"):
        raise ValueError("INoDS works only for SI, SIR, SIS infection models. User-specified infection model is ill-defined")
    return model != "SI"
#################################################################################
def extract_maxtime(edge_filename, health_filename):
    r"""Return the usable maximum timestep across the edge and health files.

    Column names are lower-cased, stripped and have underscores removed
    before lookup.  If the edge file carries no 'timestep' column (static
    network) the health file alone determines the maximum; otherwise the
    smaller of the two maxima is returned.
    """
    edges = pd.read_csv(edge_filename)
    edges.columns = [col.lower().strip().replace('_', '') for col in edges.columns]
    health = pd.read_csv(health_filename)
    health.columns = [col.lower().strip().replace('_', '') for col in health.columns]
    health_max = max(health['timestep'])
    if "timestep" in list(edges):
        return min(max(edges['timestep']), health_max)
    return health_max
#################################################################################
def create_dynamic_network(edge_filename, complete_nodelist, edge_weights_to_binary, normalize_edge_weight, is_network_dynamic, time_max):
    r"""Read an edge-list CSV and return a dictionary {timestep: networkx.Graph}.

    The file must have columns 'node1', 'node2', a third column whose name
    contains the string 'weight', and (for dynamic networks) a fourth column
    named 'timestep'.  Zero-weight edges are dropped.  For a static network
    the single edge list is replicated for every timestep 0..time_max.

    Parameters
    ----------
    edge_filename : path to the edge-list CSV
    complete_nodelist : optional list of all node ids; added (as strings) to
        every time slice so that isolated nodes are represented
    edge_weights_to_binary : if True, set every edge weight to 1
    normalize_edge_weight : if True, divide all weights by the global maximum
        weight (mutually exclusive with edge_weights_to_binary)
    is_network_dynamic : whether the edge list carries timesteps
    time_max : last timestep to generate when the network is static

    Returns
    -------
    dict mapping each timestep to a networkx.Graph with 'weight' edge attributes

    Raises
    ------
    ValueError on malformed columns or contradictory options.
    """
    df = pd.read_csv(edge_filename)
    df.columns = df.columns.str.lower()
    header = list(df)
    if [str1 for str1 in header[:2]] != ["node1", "node2"]:
        raise ValueError("The first two columns in network file should be arranged as named as 'node1', 'node2'")
    if "weight" not in header[2]:
        raise ValueError("The third column in network file should contain string = weight")
    if is_network_dynamic and (len(header) < 4 or header[3] != "timestep"):
        raise ValueError("Time-stamps are either missing or the column is not labelled as 'timestep'! If network is static then set 'is_network_dynamic' as False")
    ## rename the weight column as weight
    df.rename(columns={header[2]: 'weight'}, inplace=True)
    ## remove all zero weighted edges
    df = df[df.weight != 0]
    if not is_network_dynamic:
        if "timestep" in header:
            raise ValueError("Network dynamic set as False but the infection data has timesteps!")
        ## replicate the static edge list once per timestep 0..time_max
        n_edges = len(df.index)
        timelist = [[num] * n_edges for num in range(time_max + 1)]
        timelist = [val for sublist in timelist for val in sublist]
        df = pd.concat([df] * (time_max + 1), ignore_index=True)
        df['timestep'] = timelist
    ## validate the option combination before mutating any weights
    if edge_weights_to_binary and normalize_edge_weight:
        raise ValueError("Cannot convert edge weights to binary AND normalize edge weights! Choose one")
    if edge_weights_to_binary:
        df['weight'] = 1
    elif normalize_edge_weight:
        ## NOTE: normalization divides by the single global maximum weight,
        ## not a per-timestep maximum
        max_edgewt = max(df['weight']) / 1.
        df['weight'] = df["weight"] / max_edgewt
    G = {}
    for time1 in range(df["timestep"].min(), df["timestep"].max() + 1):
        G[time1] = nx.Graph()
    for time1 in G:
        if complete_nodelist is not None:
            complete_nodelist = [str(num) for num in complete_nodelist]
            G[time1].add_nodes_from(complete_nodelist)
        ## BUG FIX: take an explicit copy -- assigning into a .loc slice view
        ## raises pandas' SettingWithCopyWarning and the assignment may not
        ## propagate; the copy makes the astype(str) conversions reliable
        df_sub = df.loc[df['timestep'] == time1].copy()
        df_sub['node1'] = df_sub['node1'].astype(str)
        df_sub['node2'] = df_sub['node2'].astype(str)
        edge_list = list(zip(df_sub.node1, df_sub.node2))
        G[time1].add_edges_from(edge_list)
        edge_attr = dict(zip(zip(df_sub.node1, df_sub.node2), df_sub.weight))
        nx.set_edge_attributes(G[time1], edge_attr, 'weight')
    return G
##########################################################################
def extract_nodelist(edge_filename, health_filename):
    r"""Return the union of node ids (as strings) appearing in either the
    edge file ('node1'/'node2' columns) or the health file ('Node' column)."""
    health = pd.read_csv(health_filename)
    nodes = {str(node) for node in health.Node.unique()}
    edges = pd.read_csv(edge_filename)
    nodes.update(str(node) for node in edges.node1.unique())
    nodes.update(str(node) for node in edges.node2.unique())
    return list(nodes)
#######################################################################
def check_edge_weights(G):
    """Warn when edge weights do not sum to one at every time step.

    Code convergence is better if the edge weights are normalized
    at each time step; the per-slice totals are rounded to one decimal
    place before the check.
    """
    per_step_totals = []
    for timestep in G.keys():
        weights = nx.get_edge_attributes(G[timestep], 'weight').values()
        per_step_totals.append(int(round(sum(weights), 1)))
    if list(set(per_step_totals)) != [1]:
        print("Warning: Code convergence is better if the edge weights are normalized at each time step")
#################################################################
def permute_network(G1, permutation_level, complete_nodelist=None, network_dynamic = True):
    r"""Return a partially permuted copy of dynamic network *G1*.

    A fraction ``permutation_level`` of each time-slice's edges is replaced
    by randomly chosen node pairs; the remaining edges are copied unchanged.
    Every edge in the permuted network carries the slice's mean edge weight.

    Parameters
    ----------
    G1 : dict of networkx graphs keyed by timestep
    permutation_level : fraction (0-1) of edges to rewire per slice
    complete_nodelist : optional list of node ids seeded (as strings) into
        each permuted slice
    network_dynamic : if False, only the earliest slice is permuted and the
        result is copied to every other timestep
    """
    G2 = {}
    if network_dynamic:
        for time in G1.keys():
            G2[time] = nx.Graph()
            if complete_nodelist is not None:
                # node ids are handled as strings throughout
                complete_nodelist = [str(num) for num in complete_nodelist]
                G2[time].add_nodes_from(complete_nodelist)
            else: G2[time].add_nodes_from(list(G1[time].nodes))
            # all permuted edges carry the mean weight of the observed slice
            wtlist = [G1[time][node1][node2]["weight"] for node1, node2 in list(G1[time].edges)]
            edge_size = len(G1[time].edges())
            mean_wtlist = np.mean(wtlist)
            # split the edge budget into rewired vs. copied edges
            num_swaps = int(permutation_level*len(G1[time].edges))
            num_orig = len(G1[time].edges) - num_swaps
            orig_edges = list(G1[time].edges)
            shuffle(orig_edges)
            ## copy edge connections
            G2[time].add_edges_from(orig_edges[:num_orig], weight = mean_wtlist)
            for num in range(num_swaps):
                condition_met = False # skip over node pairs that already have an edge
                counter = 0
                while not condition_met:
                    node1, node2 = np.random.choice(G2[time].nodes, 2, replace=False)
                    # accept the pair unless the edge exists in both the
                    # permuted and the observed slice
                    if not (G2[time].has_edge(node1, node2) and G1[time].has_edge(node1,node2)):
                        condition_met = True
                        G2[time].add_edge(node1, node2)
                        G2[time][node1][node2]["weight"] = mean_wtlist
                    else: counter+=1
                    # NOTE(review): this branch gives up after 2*edge_size failed
                    # draws while the static branch below uses 4*edge_size --
                    # confirm the asymmetry is intended
                    if counter > 2* edge_size and not (G2[time].has_edge(node1, node2) and G1[time].has_edge(node1,node2)):
                        ##Give up after 2*#edges attempts
                        condition_met=True
                        G2[time].add_edge(node1, node2)
                        G2[time][node1][node2]["weight"] = mean_wtlist
                        print ("Warning: permutation level not acheived! Could be due to high edge density of the observed network")
    else:
        # static network: permute only the earliest slice, then replicate it
        init_time = min([time1 for time1 in G1])
        G2[init_time] = nx.Graph()
        if complete_nodelist is not None:
            complete_nodelist = [str(num) for num in complete_nodelist]
            G2[init_time].add_nodes_from(complete_nodelist)
        else: G2[init_time].add_nodes_from(G1[init_time].nodes())
        edge_size = len(G1[init_time].edges())
        wtlist = [G1[init_time][node1][node2]["weight"] for node1, node2 in G1[init_time].edges()]
        mean_wtlist = np.mean(wtlist)
        num_swaps = int(permutation_level*len(G1[init_time].edges))
        num_orig = len(G1[init_time].edges) - num_swaps
        orig_edges = list(G1[init_time].edges)
        shuffle(orig_edges)
        ## copy edge connections
        G2[init_time].add_edges_from(orig_edges[:num_orig], weight = mean_wtlist)
        for num in range(num_swaps):
            #select two random nodes from G2[time]
            condition_met = False # skip over node pairs that already have an edge
            counter=0
            while not condition_met:
                node1, node2 = np.random.choice(G2[init_time].nodes(), 2, replace=False)
                if not (G2[init_time].has_edge(node1, node2) and G1[init_time].has_edge(node1,node2)):
                    condition_met = True
                    G2[init_time].add_edge(node1, node2)
                    G2[init_time][node1][node2]["weight"] = mean_wtlist
                else:counter+=1
                if counter > 4* edge_size and not (G2[init_time].has_edge(node1, node2) and G1[init_time].has_edge(node1,node2)):
                    ##Give up after 4*#edges attempts
                    condition_met=True
                    G2[init_time].add_edge(node1, node2)
                    G2[init_time][node1][node2]["weight"] = mean_wtlist
                    print ("Warning: permutation level not acheived! Could be due to high edge density of the observed network")
        # replicate the permuted slice across all remaining timesteps
        time_list = [time1 for time1 in G1 if time1 != init_time]
        for time1 in time_list:
            G2[time1] = G2[init_time].copy()
    return G2
#############################################################################
def randomize_edges(g):
    """Randomize a network in place using double-edge swaps."""
    n_edges = g.size()  # number of edges in graph
    nx.double_edge_swap(g, nswap=100 * n_edges, max_tries=10000000 * n_edges)
###################################################################
def randomize_network(G1, complete_nodelist, network_dynamic = True):
    r""" Randomize edge connections of each network slice.
    Also, set edge weight to mean.

    Parameters
    ----------
    G1 : dict of networkx graphs keyed by timestep
    complete_nodelist : optional list of node ids seeded (as strings) into
        each randomized slice
    network_dynamic : if False only the earliest slice is randomized and then
        copied to every other timestep

    Returns
    -------
    (G2, jaccard) : the randomized network dict and its mean temporal Jaccard
        similarity with G1 (0 = no shared edges, 1 = identical edge sets)
    """
    G2 = {}
    if network_dynamic:
        for time in G1.keys():
            G2[time] = nx.Graph()
            if complete_nodelist is not None:
                # node ids are handled as strings throughout
                complete_nodelist = [str(num) for num in complete_nodelist]
                G2[time].add_nodes_from(complete_nodelist)
            else: G2[time].add_nodes_from(G1[time].nodes())
            edge_size = len(G1[time].edges())
            # all randomized edges carry the mean weight of the observed slice
            wtlist = [G1[time][node1][node2]["weight"] for node1, node2 in G1[time].edges()]
            mean_wtlist = np.mean(wtlist)
            for num in range(len(G1[time].edges())): #for each edge in G1[time]
                #select two random nodes from G2[time]
                condition_met = False # skip over node pairs that already have an edge
                counter=0
                while not condition_met:
                    node1, node2 = np.random.choice(G2[time].nodes(), 2, replace=False)
                    # accept only pairs with no edge in either the randomized
                    # or the observed slice
                    if not (G2[time].has_edge(node1, node2) or G1[time].has_edge(node1, node2)):
                        condition_met = True
                        G2[time].add_edge(node1, node2)
                        G2[time][node1][node2]["weight"] = mean_wtlist
                    else:counter+=1
                    ##Give up after 4*#edges attempts
                    if counter > 4* edge_size and not (G2[time].has_edge(node1, node2)):
                        condition_met=True
                        G2[time].add_edge(node1, node2)
                        G2[time][node1][node2]["weight"] = mean_wtlist
    else:
        # static network: randomize only the earliest slice, then replicate it
        init_time = min([time1 for time1 in G1])
        G2[init_time] = nx.Graph()
        if complete_nodelist is not None:
            complete_nodelist = [str(num) for num in complete_nodelist]
            G2[init_time].add_nodes_from(complete_nodelist)
        else: G2[init_time].add_nodes_from(G1[init_time].nodes())
        edge_size = len(G1[init_time].edges())
        wtlist = [G1[init_time][node1][node2]["weight"] for node1, node2 in G1[init_time].edges()]
        mean_wtlist = np.mean(wtlist)
        for num in range(len(G1[init_time].edges())): #for each edge in G1[time]
            #select two random nodes from G2[time]
            condition_met = False # skip over node pairs that already have an edge
            counter=0
            while not condition_met:
                node1, node2 = np.random.choice(G2[init_time].nodes(), 2, replace=False)
                if not (G2[init_time].has_edge(node1, node2) or G1[init_time].has_edge(node1, node2)):
                    condition_met = True
                    G2[init_time].add_edge(node1, node2)
                    G2[init_time][node1][node2]["weight"] = mean_wtlist
                else:counter+=1
                ##Give up after 4*#edges attempts
                if counter > 4* edge_size and not (G2[init_time].has_edge(node1, node2)):
                    condition_met=True
                    G2[init_time].add_edge(node1, node2)
                    G2[init_time][node1][node2]["weight"] = mean_wtlist
        # copy the randomized slice to every remaining timestep
        time_list = [time1 for time1 in G1 if time1 != init_time]
        for time1 in time_list:
            G2[time1] = G2[init_time].copy()
    jaccard = calculate_mean_temporal_jaccard(G1, G2)
    return G2, jaccard
#######################################################################
def stitch_health_data(health_data):
    """Impute missing timesteps between consecutive reports of equal status.

    For each node, whenever two consecutive recorded timesteps carry the
    same infection status, every timestep in between is filled in with that
    status.  The dictionary is modified in place and also returned.
    """
    for node, reports in health_data.items():
        times = sorted(reports)
        for earlier, later in zip(times, times[1:]):
            if reports[earlier] == reports[later]:
                for step in range(earlier + 1, later):
                    reports[step] = reports[later]
    return health_data
#######################################################################
def extract_health_data(health_filename, infection_type, nodelist, time_max, diagnosis_lag=False):
    r"""node_health is a dictionary of dictionary. Primary key = node id.
    Secondary key = 0/1. 0 (1) key stores chunk of days when the node is **known** to be healthy (infected).
    Dates stored as tuple of (start date, end date)

    Parameters
    ----------
    health_filename : CSV with columns (node id, timestep, diagnosis 0/1)
    infection_type : "SI"/"SIR"/"SIS"; forwarded to select_healthy_time
    nodelist : node ids to keep; reports for other nodes are ignored
    time_max : reports after this timestep are ignored
    diagnosis_lag : if True, gaps between equal consecutive reports are
        imputed via stitch_health_data before windowing

    Returns
    -------
    health_data : {node: {timestep: 0/1}} raw (and imputed) diagnoses
    node_health : {node: {0: [(start, end), ...], 1: [(start, end), ...]}}
    """
    health_data = {}
    # node ids are handled as strings throughout
    for node in nodelist: health_data[str(node)]={}
    with open (health_filename, 'r') as csvfile:
        fileread = csv.reader(csvfile, delimiter = ',')
        next(fileread, None) #skip header
        for row in fileread:
            node = str(row[0])
            timestep = int(row[1])
            diagnosis = int(row[2])
            # ignore reports for unknown nodes or beyond the study period
            if node in nodelist and timestep<=time_max:health_data[node][timestep] = diagnosis
    if diagnosis_lag: health_data = stitch_health_data(health_data)
    node_health = {}
    for node in health_data.keys():
        sick_list_node=[]
        healthy_list_node=[]
        node_health[node] = {}
        ## sort infection reported into infected (sick_list) and healthy (healthy_list) lists
        for time1 in health_data[node].keys():
            if health_data[node][time1]==1: sick_list_node.append(time1)
            if health_data[node][time1]==0: healthy_list_node.append(time1)
        if len(healthy_list_node)>0:
            healthy_list_node = sorted(healthy_list_node)
            node_health[node][0] = select_healthy_time(healthy_list_node, node, health_data, infection_type)
        if len(sick_list_node)>0:
            sick_list_node = sorted(sick_list_node)
            node_health[node][1]= select_sick_times(sick_list_node, node, health_data)
        if 1 in node_health[node]:
            for time1, time2 in node_health[node][1]:
                ##impute the missing report of sick in health data dictionary
                for day in range(time1, time2+1): health_data[node][day]=1
    return health_data, node_health
##############################################################################
def select_healthy_time(healthy_list_node, node, health_data, infection_type):
    r"""Group the time points where *node* is reported uninfected into
    (start, end) windows bounded by infection reports (or record edges).

    A window starts at a healthy report with no earlier report (or, for
    non-SIR models, whose immediately preceding report is an infection)
    and ends at a healthy report with no later report or whose immediately
    following report is an infection.
    """
    reports = health_data[node]

    def _keys_before(t):
        return [k for k in reports if k < t]

    def _keys_after(t):
        return [k for k in reports if k > t]

    if infection_type[-1].lower() == 'r':
        # SIR: once recovered a node never returns to susceptible, so a
        # healthy window can only start where no earlier report exists
        starts = [t for t in healthy_list_node if not _keys_before(t)]
    else:
        # window starts where the previous report (if any) was an infection
        starts = [t for t in healthy_list_node
                  if not _keys_before(t) or reports[max(_keys_before(t))] == 1]
    # window ends where the next report (if any) is an infection
    ends = [t for t in healthy_list_node
            if not _keys_after(t) or reports[min(_keys_after(t))] == 1]
    return list(zip(sorted(starts), sorted(ends)))
############################################################################
def select_sick_times(sick_list_node, node, health_data):
    r"""Group the time points where *node* is reported sick into
    (start, end) windows bounded by healthy reports (or record edges).

    A window starts at a sick report with no earlier report or whose
    immediately preceding report is healthy, and ends at a sick report
    with no later report or whose immediately following report is healthy.
    """
    reports = health_data[node]

    def _keys_before(t):
        return [k for k in reports if k < t]

    def _keys_after(t):
        return [k for k in reports if k > t]

    # window starts where the previous report (if any) was healthy
    starts = [t for t in sick_list_node
              if not _keys_before(t) or reports[max(_keys_before(t))] == 0]
    # window ends where the next report (if any) is healthy
    ends = [t for t in sick_list_node
            if not _keys_after(t) or reports[min(_keys_after(t))] == 0]
    return list(zip(sorted(starts), sorted(ends)))
#########################################################################
def return_contact_days_sick_nodes(node_health, seed_date, G_raw):
    r"""If infection diagnosis is lagged, then true infection day is
    inferred using infectious contact history of the focal node

    Parameters
    ----------
    node_health : {node: {0/1: [(start, end), ...]}} health windows
    seed_date : timestep of the first reported infection; sick periods
        starting exactly then are treated as seeds and skipped
    G_raw : {network name: {timestep: graph}} -- network hypotheses and/or
        null networks

    Returns
    -------
    contact_daylist : {network: {(node, start, end): [candidate days]}}
    """
    contact_daylist={key:{} for key in G_raw}
    ## select all nodes that were reported infected and sort
    for node in sorted([node1 for node1 in node_health.keys() if 1 in node_health[node1]]):
        ## removing seed nodes
        sick_days = [(time1, time2) for (time1, time2) in sorted(node_health[node][1]) if time1!= seed_date]
        ##for all time periods when the node was reported sick
        for time1, time2 in sick_days:
            #default day start (used when the node was never reported healthy
            #before this sick period)
            day_start =1
            if 0 in node_health[node]:
                ##choose all uninfected time-periods before the focal sick period
                healthy_dates = [(healthy_day1, healthy_day2) for healthy_day1, healthy_day2 in node_health[node][0] if healthy_day2 < time1]
                if len(healthy_dates)>0:
                    ##choose the latest uninfected period
                    lower_limit, upper_limit = max(healthy_dates, key=lambda x:x[1])
                    ##choose the last ever reported time-point of being uninfected
                    day_start = upper_limit
            ##contact_daylist is dictionary. Primary key = network type. Could be network hypothesis or null network
            ## secondary key (node1, time1, time2) indicated focal node and its infected time period
            ##values are the time-points when the node could have potentially contracted infection
            for network in G_raw:
                #choose only those days where nodes has contact the previous day
                contact_daylist[network][(node, time1, time2)] =[day for day in range(day_start+1, time1+1) if (day-1) in G_raw[network] and node in G_raw[network][day-1].nodes() and G_raw[network][day-1].degree(node)>0]
                ##if there are no contacts then return the entire list
                if len(contact_daylist[network][(node, time1, time2)])==0:
                    contact_daylist[network][(node, time1, time2)] = [day for day in range(day_start+1, time1+1)]
    return contact_daylist
#########################################################################
def return_potention_recovery_date(node_health, time_max):
    r"""For SIR/SIS models: map each (node, sick_start, sick_end) to its
    potential recovery time-point.

    Recovery can fall anywhere between the last report of infection and the
    first subsequent report of being uninfected; when the node is never
    reported uninfected again, *time_max* (end of study) is used instead.
    """
    recovery_daylist = {}
    ## process nodes that were ever reported infected, in sorted order
    infected_nodes = sorted(n for n in node_health if 1 in node_health[n])
    for node in infected_nodes:
        for time1, time2 in sorted(node_health[node][1]):
            recovery_date = time_max
            ## uninfected windows ending after the focal sick period starts
            healthy_periods = node_health[node].get(0, [])
            later_periods = [(d1, d2) for d1, d2 in healthy_periods if d2 > time1]
            if later_periods:
                ## first report of uninfection marks the recovery bound
                recovery_date = min(later_periods, key=lambda p: p[1])[0]
            recovery_daylist[(node, time1, time2)] = recovery_date
    return recovery_daylist
####################################################################################
def find_seed_date(node_health):
    r"""Return the earliest time-point at which any infection was reported.

    Scans every node's infected (key 1) windows and returns the smallest
    window start date.
    """
    first_days = [
        period[0]
        for timeline in node_health.values()
        for status, periods in timeline.items() if status == 1
        for period in periods
    ]
    return sorted(first_days)[0]
########################################################################
def calculate_jaccard(g1, g2):
    r"""Return the Jaccard similarity of the edge sets of two graphs.

    Edges are canonicalized as sorted tuples so orientation is ignored.
    """
    set1 = {tuple(sorted(edge)) for edge in g1.edges()}
    set2 = {tuple(sorted(edge)) for edge in g2.edges()}
    shared = len(set1 & set2)
    union = len(set1 | set2)
    return shared / float(union)
########################################################################
def calculate_mean_temporal_jaccard(g1, g2):
    r"""Return the mean, over timesteps, of the edge-set Jaccard similarity
    between two dynamic networks keyed by timestep.

    A timestep where both edge sets are empty contributes 0.
    """
    similarities = []
    for timestep in g1:
        set1 = {tuple(sorted(edge)) for edge in g1[timestep].edges()}
        set2 = {tuple(sorted(edge)) for edge in g2[timestep].edges()}
        union = len(set1 | set2)
        similarities.append(len(set1 & set2) / float(union) if union else 0)
    return np.mean(similarities)
########################################################
def compute_diagnosis_lag_truth(graph, contact_datelist, filename):
    r"""Recover 'true' diagnosis-lag quantiles for (simulated) health data.

    Reads *filename* to find each node's infection date, locates that date
    among the node's candidate contact days (restricted to days preceded by
    at least one network contact), and appends the discrete-uniform CDF value
    of its position, i.e. the lag truth on a [0, 1] scale.

    NOTE(review): assumes every node in contact_datelist has a diagnosis==1
    row in the file and that its infection date appears in the filtered
    daylist -- otherwise the position lookup raises -- TODO confirm.
    """
    # NOTE(review): diag_date is assigned but never used in this function
    diag_date = {}
    infection_date={}
    lag_truths = []
    with open (filename, 'r') as csvfile:
        fileread = csv.reader(csvfile, delimiter = ',')
        next(fileread, None) #skip header
        for row in fileread:
            node = row[0]
            timestep = int(row[1])
            diagnosis = int(row[2])
            # keep the last diagnosed timestep per node
            if diagnosis==1: infection_date[node] = timestep
    for node, time1, time2 in sorted(contact_datelist):
        # candidate days on which the node had contact the previous day
        daylist = [day for day in contact_datelist[(node, time1, time2)] if graph[day-1].degree(node)>0]
        # position of the true infection date within the candidate list
        pos = [pos for pos, date in enumerate(daylist) if date==infection_date[node]][0]
        lag_truths.append(ss.randint.cdf(pos, 0, len(daylist)))
    return lag_truths
######################################################333
def plot_beta_results(sampler, filename):
    r"""Plot MCMC results for the transmission parameter beta and save to file.

    Left panel: the trace of every walker's position for the first parameter
    (beta) across simulation steps.  Right panel: a normalized histogram of
    the pooled posterior samples.

    Parameters
    ----------
    sampler : emcee-style sampler exposing .chain with shape
        (walkers, steps, parameters)
    filename : path the figure is written to
    """
    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(15, 6))
    ax1.plot(sampler.chain[:, :, 0].T, color="k", lw=0.1)
    ax1.set_ylabel("Walker positions for $beta$")
    ax1.set_xlabel("Simulation step")
    samples = sampler.chain[:, :, 0].reshape((-1, 1))
    ## BUG FIX: the 'normed' keyword was deprecated in matplotlib 2.1 and
    ## removed in 3.x; 'density' is the supported equivalent
    ax2.hist(samples, bins=50, histtype="step", density=True, label="posterior", color="k", linewidth=2)
    ax2.legend(frameon=False, loc="best")
    ax2.set_xlabel("$beta$ posterior")
    plt.tight_layout()
    plt.savefig(filename)
########################################################################