-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph.py
85 lines (71 loc) · 2.45 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
This script allows you to draw a graph of the data from a data.json file generated by the ETL process.
Run it with: `python graph.py`
"""
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
def load_data(path: str = "data.json") -> pd.DataFrame:
    """
    Load the JSON file produced by the ETL process into a DataFrame.

    Args:
        path: Location of the JSON file. Defaults to ``data.json`` in the
            current working directory (the file the script always read).

    Returns:
        The parsed DataFrame. Date-like columns are not converted
        (``convert_dates=False``), so date values keep their raw JSON form.
    """
    with open(path, 'r', encoding="UTF-8") as f:
        return pd.read_json(f, convert_dates=False)
def _article_title(article):
    """Return ``article["title"]``, falling back to ``article["scientific_title"]`` when absent."""
    try:
        return article["title"]
    except KeyError:
        # Some article records carry only a scientific_title key.
        return article["scientific_title"]


def create_pandas_edgelist(data: pd.DataFrame) -> pd.DataFrame:
    """
    Build an edge-list DataFrame (one row per edge) from the ETL output.

    Each input row produces two edges: drug -> article title and
    drug -> journal. All article edges come first, then all journal edges.
    Both edges carry the row's ``relationship`` and ``date``.

    Args:
        data: DataFrame as returned by ``load_data``. Each row is expected to
            have ``drug`` (a mapping with a ``drug`` key), ``article`` (a
            mapping with ``title`` or ``scientific_title``), ``journal``,
            ``relationship`` and ``date``.

    Returns:
        DataFrame with columns ``source``, ``target``, ``relationship`` and
        ``date``. When ``data`` has no rows it is empty but still has those
        columns.

    Raises:
        KeyError: if a row is missing one of the fields listed above.
    """
    columns = ["source", "target", "relationship", "date"]
    # DataFrame.append was removed in pandas 2.0 and copied the whole frame on
    # every call, so collect plain records and build the frame once at the end.
    article_edges = []
    journal_edges = []
    for _, row in data.iterrows():
        source = row["drug"]["drug"]
        article_edges.append({
            "source": source,
            "target": _article_title(row["article"]),
            "relationship": row["relationship"],
            "date": row["date"],
        })
        journal_edges.append({
            "source": source,
            "target": row["journal"],
            "relationship": row["relationship"],
            "date": row["date"],
        })
    return pd.DataFrame(article_edges + journal_edges, columns=columns)
def graph() -> bool:
    """
    Load the ETL output, build a directed graph from it and display it.

    Nodes are drugs, article titles and journals. Every edge is labelled
    "<relationship> the <date>".

    NOTE(review): the graph is a DiGraph, so repeated (source, target) pairs
    (e.g. one drug mentioned several times in the same journal) collapse into
    a single edge showing only one label. Confirm this is intended.

    Returns:
        Always ``True`` after ``plt.show()`` returns.
    """
    digraph = nx.from_pandas_edgelist(
        create_pandas_edgelist(load_data()),
        edge_attr=True,
        create_using=nx.DiGraph(),
    )
    layout = nx.spring_layout(digraph)
    labels = {
        (src, dst): f'{attrs["relationship"]} the {attrs["date"]}'
        for src, dst, attrs in digraph.edges(data=True)
    }
    plt.figure(figsize=(12, 8))
    nx.draw_networkx(digraph, layout, font_size=6, node_size=1000)
    nx.draw_networkx_edge_labels(digraph, layout, edge_labels=labels, font_size=5)
    plt.show()
    return True
if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    succeeded = graph()
    if succeeded:
        print('Graph created successfully.')