-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_model.py
138 lines (123 loc) · 5.14 KB
/
run_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from src.model import configSimulation, simulationLoop
import jax
import sys
import time
import os
from functools import partial
from jax import block_until_ready, jit
import matplotlib.pyplot as plt
import numpy as np
# Run from the directory that contains this script so that the relative
# "test/..." and "results/..." paths used in this file resolve no matter where
# the script is launched from.  abspath() guards against a bare relative
# __file__ (e.g. "run_model.py"): its dirname is "" and os.chdir("") raises.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Switch JAX to 64-bit mode before any arrays are created.
jax.config.update("jax_enable_x64", True)
def resolve_config_filename(argv):
    """Return the path of the YAML network config selected by *argv*.

    When no model name is given on the command line the default network is
    used; otherwise ``argv[1]`` is taken as the model name.  In both cases the
    result has the form ``test/<modelname>/<modelname>.yml``.
    """
    if len(argv) < 2:
        # ``modelname`` is the bare network name, not a path; the path is
        # assembled exactly once in the return statement below.
        # base cases
        #modelname = "single-artery"
        #modelname = "tapering"
        #modelname = "conjunction"
        #modelname = "bifurcation"
        #modelname = "aspirator"
        # openBF-hub
        modelname = "adan56"
        # vascularmodels.com
        #modelname = "0007_H_AO_H"
        #modelname = "0029_H_ABAO_H"
        #modelname = "0053_H_CERE_H"
    else:
        modelname = argv[1]
    return "test/" + modelname + "/" + modelname + ".yml"


# Path handed to configSimulation; never left empty, even without CLI arguments.
config_filename = resolve_config_filename(sys.argv)
verbose = True
# Build the simulation state from the selected config file.
# NOTE(review): ``edges`` appears twice among the unpack targets, so the later
# position overwrites the earlier one -- confirm against the tuple that
# configSimulation actually returns.
sim_setup = configSimulation(config_filename, verbose)
(N, B, J,
 sim_dat, sim_dat_aux,
 sim_dat_const, sim_dat_const_aux,
 timepoints, conv_tol, Ccfl, edges, input_data, rho,
 masks, strides, edges,
 vessel_names, cardiac_T) = sim_setup

if verbose:
    starting_time = time.time_ns()

# Compile the loop with its first three positional arguments (N, B, J) static.
sim_loop_old_jit = jit(simulationLoop, static_argnums=(0, 1, 2))
loop_args = (
    N, B, J,
    sim_dat, sim_dat_aux,
    sim_dat_const, sim_dat_const_aux,
    timepoints, conv_tol, Ccfl, input_data, rho,
    masks, strides, edges,
)
# block_until_ready waits for the results, so the time printed below spans the
# whole call rather than just its dispatch.
sim_dat, t, P = block_until_ready(sim_loop_old_jit(*loop_args))

if verbose:
    # time_ns() yields nanoseconds; convert the difference to seconds.
    ending_time = (time.time_ns() - starting_time) / 1.0e9
    print(f"elapsed time = {ending_time} seconds")
# save data for unittests
#np.savetxt("test/test_data/bifurcation_sim_dat.dat", sim_dat)
#np.savetxt("test/test_data/bifurcation_t.dat", t)
#np.savetxt("test/test_data/bifurcation_P.dat", P)
#jnp.set_printoptions(threshold=sys.maxsize)
# Derive the network name from the config path: keep the last "/"-separated
# component, then everything before its first ".".
filename = config_filename.rpartition("/")[2]
network_name = filename.partition(".")[0]
# Hand-written vessel name tables.  The numeric suffixes presumably correspond
# to the vascularmodels.com networks 0007_H_AO_H, 0029_H_ABAO_H and
# 0053_H_CERE_H -- verify before relying on the ordering.
# Spelling of every entry is kept exactly as-is, since the strings may be
# matched against names used elsewhere.
vessel_names_0007 = [
    "ascending aorta",
    "right subclavian artery",
    "right common carotid artery",
    "arch of aorta I",
    "brachiocephalic artery",
    "arch of aorta II",
    "left common carotid artery",
    "left subclavian artery",
    "descending aorta",
]

vessel_names_0029 = [
    "aorta I",
    "left common iliac artery I",
    "left internal iliac artery",
    "left common iliac artery II",
    "right common iliac artery I",
    "celiac trunk II",
    "celiac branch",
    "aorta IV",
    "left renal artery",
    "aorta III",
    "superior mesentric artery",
    "celiac trunk I",
    "aorta II",
    "aorta V",
    "right renal artery",
    "right common iliac artery II",
    "right internal iliac artery",
]

vessel_names_0053 = [
    "right vertebral artery I",
    "left vertebral artery II",
    "left posterior meningeal branch of vertebral artery",
    "basilar artery III",
    "left anterior inferior cerebellar artery",
    "basilar artery II",
    "right anterior inferior cerebellar artery",
    "basilar artery IV",
    "right superior cerebellar artery",
    "basilar artery I",
    "left vertebral artery I",
    "right posterior cerebellar artery I",
    "left superior cerebellar artery",
    "left posterior cerebellar artery I",
    "right posterior central artery",
    "right vertebral artery II",
    "right posterior meningeal branch of vertebral artery",
    "right posterior cerebellar artery II",
    "right posterior comunicating artery",
]
# For every vessel: load the reference pressure trace, compute the relative
# L2 difference to the simulated trace, and save an overlay plot as PDF.

# Directory holding the reference results ("P_jl" in the plot title).  The
# default is the original hard-coded location; set OPENBF_TEST_DIR to point
# the script at a different checkout.
openbf_test_dir = os.environ.get(
    "OPENBF_TEST_DIR", "/home/diego/studies/uni/thesis_maths/openBF/test/")
reference_dir = os.path.join(openbf_test_dir, network_name, network_name + "_results")

# savefig does not create missing directories, so create the target up front.
output_dir = "results/" + network_name + "_results/"
os.makedirs(output_dir, exist_ok=True)

# Pressures are divided by this factor before plotting on the "P[mmHg]" axis
# (1 mmHg = 133.322 Pa).
PA_PER_MMHG = 133.322
# Column stride per vessel in P -- presumably the number of stored nodes per
# vessel; TODO confirm against simulationLoop.
COLUMNS_PER_VESSEL = 5
# Which of a vessel's columns is compared.
node = 2
# Column of the reference file; the +1 offset presumably skips a leading time
# column -- TODO confirm against the "_P.last" file layout.
index_jl = 1 + node

#plt.rcParams.update({'font.size': 20})
for i, vessel_name in enumerate(vessel_names):
    # Use the enumerate index directly: list.index() would rescan the list on
    # every iteration and return the first match for duplicated names.
    index_jax = COLUMNS_PER_VESSEL * i + node

    P0 = np.loadtxt(os.path.join(reference_dir, vessel_name + "_P.last"))
    P0 = P0[:, index_jl]
    P_jax = P[:, index_jax]

    # Relative L2 error ||P_JAX - P_jl|| / ||P_jl|| (a plain ratio).
    diff = P_jax - P0
    res = np.sqrt(diff.dot(diff) / P0.dot(P0))

    fig, ax = plt.subplots()
    ax.set_xlabel("t[s]")
    ax.set_ylabel("P[mmHg]")
    # The title labels the number with "%", so scale the ratio by 100.
    ax.set_title("network: " + network_name + ", # vessels: " + str(N) + ", vessel name: " + vessel_name + ", \n relative error = |P_JAX-P_jl|/|P_jl| = " + str(100 * res) + "%")
    #plt.title("network: " + network_name + ", vessel name: " + vessel_names_0053[i])
    #plt.title(vessel_names_0053[i])
    #plt.title("vessel name: " + vessel_name)
    # Fold time onto one cardiac cycle for the x-axis.
    ax.plot(t % cardiac_T, P_jax / PA_PER_MMHG)
    ax.plot(t % cardiac_T, P0 / PA_PER_MMHG)
    #plt.legend(["P_JAX", "P_jl"], loc="lower right")
    #plt.axis("off")
    fig.tight_layout()
    fig.savefig(output_dir + network_name + "_" + vessel_name.replace(" ", "_") + "_P.pdf")#, bbox_inches='tight')
    # Close the figure so open figures do not accumulate across vessels.
    plt.close(fig)