-
Notifications
You must be signed in to change notification settings - Fork 0
/
linear_threshold_model_association.py
223 lines (172 loc) · 9.58 KB
/
linear_threshold_model_association.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
# -*- coding: utf-8 -*-
"""
@author: ludovico coletta
@author: gabriele de leonardis
"""
import numpy as np
import pandas as pd # type: ignore
import sys
import time
from multiprocessing import Pool
def run_cascade_single_population(adj_matrix, thr, seed_node_index):
    """Run a linear-threshold cascade seeded at a single node.

    A node becomes infected when the fraction of its total incoming
    weight that touches already-infected nodes exceeds ``thr`` (strict
    inequality). Spreading stops when every node is infected or after
    31 iterations, whichever comes first.

    Parameters
    ----------
    adj_matrix : 2-D numpy array
        Weighted (square) adjacency matrix.
    thr : float
        Infection threshold on the infected-input fraction.
    seed_node_index : int
        Index of the initially infected node.

    Returns
    -------
    list of lists
        Element 0 is ``[seed_node_index]``; each subsequent element
        lists the nodes whose infected-input ratio exceeded the
        threshold at that iteration (already-infected nodes reappear,
        since their ratio is 1).
    """
    n_nodes = adj_matrix.shape[0]
    total_input = np.sum(adj_matrix, axis=0)
    infected = np.zeros(n_nodes)
    infected[seed_node_index] = 1
    history = [np.where(infected == 1)[0].tolist()]
    iteration = 0
    while int(np.sum(infected)) < n_nodes:
        # hard cap: give up after 31 sweeps even if not fully infected
        if iteration > 30:
            break
        active = np.where(infected == 1)[0]
        # keep only the connections that touch at least one infected node
        touch_mask = np.zeros(adj_matrix.shape)
        touch_mask[active, :] = 1
        touch_mask[:, active] = 1
        infected_input = np.sum(adj_matrix * touch_mask, axis=0)
        newly_active = np.where(infected_input / total_input > thr)[0]
        history.append(newly_active.tolist())
        infected[newly_active] = 1
        iteration += 1
    return history
def find_thr(adj_matrix, starting_thr):
    """Search for the critical single-seed cascade threshold of ``adj_matrix``.

    For every seed node, the threshold is doubled for as long as a
    single-seed cascade still infects the whole network; the node whose
    surviving threshold is smallest (the "bottleneck" node) then gets a
    fine 100-point linear scan between its last two visited thresholds.

    Returns a tuple ``(bottleneck_node, thr)`` where ``bottleneck_node``
    is an index array (possibly several tied nodes) and ``thr`` is the
    largest scanned threshold at which the bottleneck cascade still
    infects everything.
    """
    visited_thresholds_per_node = [None] * adj_matrix.shape[0]
    for seed_node_index in range(adj_matrix.shape[0]):
        visited_thresholds_per_node[seed_node_index] = []
        thr = starting_thr
        # coarse search: double thr while the cascade still covers all nodes
        for dummy_thr in range(1000):
            list_of_infected_nodes_per_iter = run_cascade_single_population(adj_matrix, thr, seed_node_index)
            if len(list_of_infected_nodes_per_iter[-1]) == adj_matrix.shape[0]:
                thr *= 2
                # NOTE(review): the appended value is the DOUBLED threshold,
                # which has not been tested yet at this point.
                visited_thresholds_per_node[seed_node_index].append(thr)
            elif dummy_thr == 0 and len(list_of_infected_nodes_per_iter[-1]) != adj_matrix.shape[0]:
                # starting threshold already too high: restart much lower
                thr /= 100
            else:
                break
        # NOTE(review): if even starting_thr/100 fails on the second pass,
        # this node's list stays empty and the [-1] lookups below would
        # raise IndexError — confirm inputs make this unreachable.
        # debug
        print(f"Node {seed_node_index}: visited thresholds = {visited_thresholds_per_node[seed_node_index]}")
    max_thresholds_per_node = np.asarray([visited_thresholds_per_node[ii][-1] for ii in range(len(visited_thresholds_per_node))])
    # bottleneck = node(s) with the smallest surviving threshold
    bottleneck_node = np.where(max_thresholds_per_node == np.min(max_thresholds_per_node))[0]
    # ensure at least two thresholds before accessing them
    if len(visited_thresholds_per_node[bottleneck_node[0]]) < 2:
        raise ValueError(f"Not enough thresholds visited for bottleneck node {bottleneck_node[0]}: {visited_thresholds_per_node[bottleneck_node[0]]}")
    # fine scan between the last two coarse thresholds of the first bottleneck node
    thrs = np.linspace(visited_thresholds_per_node[bottleneck_node[0]][-2], visited_thresholds_per_node[bottleneck_node[0]][-1], 100, endpoint=True)
    visited_thresholds_of_bottleneck_node = []
    visited_thresholds_of_bottleneck_node.append(thrs[0])
    final_thr_per_node = []
    for node in bottleneck_node:
        # NOTE(review): visited_thresholds_of_bottleneck_node is shared
        # across tied bottleneck nodes (never reset per node), so a node
        # that fails immediately inherits the previous node's last value —
        # confirm this cross-node carry-over is intended.
        for final_thr in thrs:
            list_of_infected_nodes_per_iter = run_cascade_single_population(adj_matrix, final_thr, node)
            if len(list_of_infected_nodes_per_iter[-1]) == adj_matrix.shape[0]:
                visited_thresholds_of_bottleneck_node.append(final_thr)
            else:
                break
        final_thr_per_node.append(visited_thresholds_of_bottleneck_node[-1])
    return bottleneck_node, np.min(final_thr_per_node)
def run_cascade_multiple_populations(adj_matrix, thr, n_pop, n_sim):
    """Run ``n_sim`` competitive cascades with ``n_pop`` random seed nodes each.

    Every run seeds ``n_pop`` populations at distinct random nodes; at
    each sweep, any unclaimed node whose infected-input fraction reaches
    ``thr`` (>=) for some population is claimed by the population giving
    it the strongest input (random tie-break). Runs that stall before
    covering the network are discarded; only successful runs count.

    Parameters
    ----------
    adj_matrix : 2-D numpy array
        Weighted (square) adjacency matrix.
    thr : float
        Claim threshold on the infected-input fraction.
    n_pop : int
        Number of competing seed populations per run.
    n_sim : int
        Number of successful runs to accumulate.

    Returns
    -------
    (list, 2-D numpy array)
        The per-run population memberships, and the association matrix:
        entry (i, j) is the fraction of successful runs in which i and j
        ended up in the same population (diagonal left at 0).

    BUG FIX vs. the original: a stuck run used to call
    ``infected_nodes_per_run.pop()``, deleting a *previous successful*
    run whose counts had already been added to ``association_matrix``
    (and raising IndexError if the very first run stalled). Stuck runs
    are now simply discarded.
    """
    n_nodes = adj_matrix.shape[0]
    # invariant across runs: hoisted out of the simulation loop
    input_to_node = np.sum(adj_matrix, axis=0)
    infected_nodes_per_run = []
    association_matrix = np.zeros((n_nodes, n_nodes))
    while len(infected_nodes_per_run) < n_sim:
        seed_node_indices = sorted(np.random.choice(n_nodes, size=n_pop, replace=False).tolist())
        seed_node_indices = [[ii] for ii in seed_node_indices]
        infected_nodes = np.zeros((n_nodes))
        infected_nodes[seed_node_indices] = 1
        stuck = False
        while int(np.sum(infected_nodes)) < n_nodes:
            # per-population infected-input fraction for every node
            per_pop_ratios = [None] * n_pop
            for pop_index, pop_nodes in enumerate(seed_node_indices):
                mask_array = np.zeros((adj_matrix.shape))
                mask_array[pop_nodes, :] = 1
                mask_array[:, pop_nodes] = 1
                infected_inputs = np.sum(adj_matrix * mask_array, axis=0)
                per_pop_ratios[pop_index] = infected_inputs / input_to_node
            input_per_node = np.vstack(per_pop_ratios)
            nodes_to_check = np.where(input_per_node >= thr)[1].tolist()
            already_claimed = [elem for sublist in seed_node_indices for elem in sublist]
            nodes_to_check = sorted(set(ii for ii in nodes_to_check if ii not in already_claimed))
            if len(nodes_to_check) == 0:
                # no population can claim any new node: discard this run
                stuck = True
                break
            for node in nodes_to_check:
                # the population with the strongest input wins the node;
                # ties are broken uniformly at random
                indices_of_winner = np.where(input_per_node[:, node] == np.max(input_per_node[:, node]))[0]
                if indices_of_winner.size == 1:
                    seed_node_indices[indices_of_winner[0]].append(node)
                else:
                    seed_node_indices[np.random.choice(indices_of_winner.size, size=1)[0]].append(node)
            for pop_nodes in seed_node_indices:
                infected_nodes[pop_nodes] = 1
        if stuck:
            continue
        infected_nodes_per_run.append(seed_node_indices)
        # co-membership counts: +1 for every ordered pair in the same population
        for node_set in seed_node_indices:
            for i in node_set:
                for j in node_set:
                    if i != j:
                        association_matrix[i, j] += 1
    association_matrix /= n_sim
    return infected_nodes_per_run, association_matrix
def main(input_file_path, n_pop):
    """Process one subject's connectome: clean, find threshold, run cascades.

    Reads a comma-separated connectome at ``input_file_path``, removes
    unconnected and weakly-connected nodes, estimates the critical
    cascade threshold, runs 10000 competitive cascades with ``n_pop``
    seeds, and writes the association matrix plus the removed-node mask
    under ``derivatives/<sub_id>/dwi/``.

    Changes vs. the original: the unused ``n_steps_needed`` computation
    (one full cascade per node, result discarded) was removed, and the
    output directory is created if missing so ``np.savetxt`` cannot fail
    on a fresh tree.
    """
    import os  # function-local so the file's top-level imports stay untouched

    # subject ID assumed three path levels up, e.g.
    # derivatives/<sub_id>/dwi/<file>.csv -- TODO confirm against caller
    sub_id = input_file_path.split('/')[-3]
    adj_matrix = pd.read_csv(input_file_path, delimiter=',', header=None).to_numpy().astype(float)
    print(f"now processing: {sub_id} with {n_pop} seeds competitive scenario")
    # nodes with no incoming weight at all
    zero_rows = np.where(np.sum(adj_matrix, 0) == 0)[0].tolist()
    # nodes with 5 or fewer non-zero connections
    low_connection_nodes = np.where(np.sum(adj_matrix > 0, axis=0) <= 5)[0].tolist()
    all_removed_nodes = sorted(set(zero_rows + low_connection_nodes))
    # 0/1 mask over the ORIGINAL matrix: 0 marks removed rows/columns
    zero_connection_nodes_matrix = np.ones_like(adj_matrix, dtype=int)
    zero_connection_nodes_matrix[all_removed_nodes, :] = 0
    zero_connection_nodes_matrix[:, all_removed_nodes] = 0
    # drop the removed nodes from the working matrix
    adj_matrix_clean = np.delete(adj_matrix, all_removed_nodes, axis=0)
    adj_matrix_clean = np.delete(adj_matrix_clean, all_removed_nodes, axis=1)
    starting_thr = 0.0015
    start_time = time.time()
    bottl_nodes, thr = find_thr(adj_matrix_clean, starting_thr)
    print(f"Time to find threshold: {time.time() - start_time} seconds")
    start_time = time.time()
    _, association_matrix = run_cascade_multiple_populations(adj_matrix_clean, thr, n_pop, 10000)
    print(f"Time to run competitive cascades: {time.time() - start_time} seconds")
    association_matrix_filename = f"derivatives/{sub_id}/dwi/association_matrix_{sub_id}_{n_pop}seeds.csv"
    removed_nodes_filename = f"derivatives/{sub_id}/dwi/removed_nodes_{sub_id}_{n_pop}seeds.csv"
    # both outputs share the same directory; create it if absent
    os.makedirs(os.path.dirname(association_matrix_filename), exist_ok=True)
    np.savetxt(association_matrix_filename, association_matrix, delimiter=",")
    np.savetxt(removed_nodes_filename, zero_connection_nodes_matrix, delimiter=",", fmt="%d")
if __name__ == "__main__":
    # File paths arrive on stdin (one per line); the seed count is argv[1].
    # Check if script is executed with correct num of args
    if len(sys.argv) != 2:
        print("Correct syntax: cat input_file_list | python [this_script.py] [n_pop]")
        sys.exit(1)
    # skip blank lines: a trailing newline in the piped list would
    # otherwise crash main() when it splits the empty path
    input_file_paths = [line.strip() for line in sys.stdin if line.strip()]
    n_pop = int(sys.argv[1])
    # NOTE(review): 142 workers is hard-coded for a specific machine — confirm
    # with-statement closes and joins the pool (the original leaked it)
    with Pool(processes=142) as pool:
        pool.starmap(main, [(file_path, n_pop) for file_path in input_file_paths])
########## HOW TO RUN ###########
# from terminal (bash), cd to dataset folder
# you should also have a folder called "code" in position ../code
# choose number of populations
# therefore type these commands:
"""
path_der="derivatives/"
find "$path_der" -type f -name '*5000000mio_connectome.csv' > "$path_der/connectome_files.txt"
cat "$path_der/connectome_files.txt" | python ../code/linear-threshold-model/linear_threshold_model_association.py [n_pop] > sim_parallel.txt
"""