Skip to content

Commit b6b993e

Browse files
authored
Merge pull request #1 from asapdiscovery/tweaks_for_demo
Tweaks
2 parents 7ca3317 + d4b2308 commit b6b993e

File tree

1 file changed

+141
-62
lines changed

1 file changed

+141
-62
lines changed

app.py

Lines changed: 141 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
asap_prod_streamlit = int(os.getenv("ASAP_OE_PROD_STREAMLIT", None))
55

66
if asap_prod_streamlit == 1:
7+
78
def sort_out_openeye_license():
89
import os
10+
911
# need to write the license file to disk
1012
txt = st.secrets.openeye_credentials.license_file_txt
1113
if not txt:
@@ -18,11 +20,11 @@ def sort_out_openeye_license():
1820
sort_out_openeye_license()
1921

2022

21-
2223
import pandas as pd
2324
import numpy as np
2425
from asapdiscovery.ml.inference import GATInference
2526
from asapdiscovery.ml.models import ASAPMLModelRegistry
27+
import matplotlib.pyplot as plt
2628
from rdkit import Chem
2729
from streamlit_ketcher import st_ketcher
2830
from io import StringIO
@@ -31,6 +33,7 @@ def sort_out_openeye_license():
3133
# need to update the registry periodically
3234
schedule.every(4).hours.do(ASAPMLModelRegistry.update_registry)
3335

36+
3437
def _is_valid_smiles(smi):
3538
if smi is None or smi == "":
3639
return False
@@ -42,9 +45,11 @@ def _is_valid_smiles(smi):
4245
return True
4346
except:
4447
return False
45-
48+
49+
4650
def sdf_str_to_rdkit_mol(sdf):
4751
from io import BytesIO
52+
4853
bio = BytesIO(sdf.encode())
4954
suppl = Chem.ForwardSDMolSupplier(bio, removeHs=False)
5055
mols = [mol for mol in suppl if mol is not None]
@@ -57,39 +62,46 @@ def convert_df(df):
5762
return df.to_csv().encode("utf-8")
5863

5964

60-
61-
62-
6365
# Set the title of the Streamlit app
64-
st.title('ASAPDiscovery Machine Learning')
66+
st.title("ASAP Discovery Local Models (ML)")
6567

66-
st.markdown("## Intro")
68+
st.markdown("## Background")
6769

68-
st.markdown("The [ASAPDiscovery antiviral drug discovery consortium](https://asapdiscovery.org) has developed a series of machine learning models (primarily Graph Attention Networks (GATs)) to predict molecular properties based on our experimental data, much of which is [available](https://asapdiscovery.org/outputs/) as part of our [open science](https://asapdiscovery.org/open-science/) and public disclosure policy.")
69-
st.markdown("These models are trained on a variety of endpoints, including in-vitro activity, assayed LogD, and more \n Some models are specific to a target, while others are global models that predict properties across all targets.")
70-
st.markdown("This web app gives you easy API-less access to the models, I hope you find it useful!\n As scientists we should always be looking to get our models into people's hands as easily as possible.")
71-
st.markdown("These models are trained bi-weekly. The latest models are used for prediction where possible. Note that predictions are pre-alpha and are provided as is, we are still working very actively on improving and validating models.")
70+
st.markdown(
71+
"**The [ASAP Discovery antiviral drug discovery consortium](https://asapdiscovery.org) has developed a series of local machine learning models (GAT architecture) to predict properties based on our local data, much of which is [available](https://asapdiscovery.org/outputs/) as part of our [open science policy](https://asapdiscovery.org/open-science/).**"
72+
)
73+
st.markdown(
74+
"**These models are trained on a variety of experimental endpoints that are found in ASAP's CDD vault, including biochemical and antiviral potency, assayed LogD, and more. Some models are specific to a target, while others are global models that predict properties across all targets.**"
75+
)
76+
st.markdown(
77+
"This web app gives you easy access to the trained models without having to write or execute any code. The intention is to empower anyone across ASAP to make these predictions."
78+
)
79+
st.markdown("---")
80+
st.markdown(
81+
"These models are trained bi-weekly. The latest models are used for prediction where possible. Note that predictions are pre-alpha and are provided as is, work is on-going on improving and validating these models. As a general rule of thumb, predictions on your data will be better when your query compound(s) is/are closer chemically to the compounds in the CDD. Are you having problems using this UI or do you have a feature request? Please open an issue on [our issue tracker](https://github.com/asapdiscovery/asap-ml-streamlit/issues/new)."
82+
)
7283

73-
st.markdown("## Select input")
84+
st.markdown("## Input :clipboard:")
7485

7586

76-
input = st.selectbox("How would you like to enter your input?", ["Upload a CSV file", "Draw a molecule", "Enter SMILES", "Upload an SDF file"])
87+
input = st.selectbox(
88+
"How would you like to enter your input?",
89+
["Upload a CSV file", "Draw a molecule", "Enter SMILES", "Upload an SDF file"],
90+
)
7791

7892
multismiles = False
7993
if input == "Draw a molecule":
80-
st.write("Draw a molecule")
8194
smiles = st_ketcher(None)
8295
if _is_valid_smiles(smiles):
83-
st.success("Valid SMILES string", icon="✅")
96+
st.success("Valid molecule", icon="✅")
8497
else:
85-
st.error("Invalid SMILES string", icon="🚨")
98+
st.error("Invalid molecule", icon="🚨")
8699
st.stop()
87100
smiles = [smiles]
88101
df = pd.DataFrame(smiles, columns=["SMILES"])
89102
column = "SMILES"
90103
smiles_column = df["SMILES"]
91104
elif input == "Enter SMILES":
92-
st.write("Enter SMILES")
93105
smiles = st.text_input("Enter a SMILES string")
94106
if _is_valid_smiles(smiles):
95107
st.success("Valid SMILES string", icon="✅")
@@ -101,10 +113,10 @@ def convert_df(df):
101113
column = "SMILES"
102114
smiles_column = df["SMILES"]
103115
elif input == "Upload a CSV file":
104-
st.write("Upload a CSV file")
105-
106116
# Create a file uploader for CSV files
107-
uploaded_file = st.file_uploader("Choose a CSV file to upload your predictions to", type="csv")
117+
uploaded_file = st.file_uploader(
118+
"Choose a CSV file to upload your predictions to", type="csv"
119+
)
108120

109121
# If a file is uploaded, parse it into a DataFrame
110122
if uploaded_file is not None:
@@ -119,14 +131,20 @@ def convert_df(df):
119131
# check if the smiles are valid
120132
valid_smiles = [_is_valid_smiles(smi) for smi in smiles_column]
121133
if not all(valid_smiles):
122-
st.error("Some of the SMILES strings are invalid, please check the input", icon="🚨")
134+
st.error(
135+
"Some of the SMILES strings are invalid, please check the input", icon="🚨"
136+
)
123137
st.stop()
124-
st.success("All SMILES strings are valid, proceeding with prediction", icon="✅")
138+
st.success(
139+
f"All SMILES strings are valid (n={len(valid_smiles)}), proceeding with prediction",
140+
icon="✅",
141+
)
125142

126143
elif input == "Upload an SDF file":
127-
st.write("Upload an SDF file")
128144
# Create a file uploader for SDF files
129-
uploaded_file = st.file_uploader("Choose a SDF file to upload your predictions to", type="sdf")
145+
uploaded_file = st.file_uploader(
146+
"Choose a SDF file to upload your predictions to", type="sdf"
147+
)
130148
# read with rdkit
131149
if uploaded_file is not None:
132150
# To convert to a string based IO:
@@ -136,17 +154,20 @@ def convert_df(df):
136154
mols = sdf_str_to_rdkit_mol(string_data)
137155
smiles = [Chem.MolToSmiles(m) for m in mols]
138156
df = pd.DataFrame(smiles, columns=["SMILES"])
139-
# st.error("Error reading the SDF file, please check the input", icon="🚨")
140-
# st.stop()
157+
# st.error("Error reading the SDF file, please check the input", icon="🚨")
158+
# st.stop()
141159
else:
142160
st.stop()
143-
144-
st.success("All SMILES strings are valid, proceeding with prediction", icon="✅")
161+
162+
st.success(
163+
f"All molecule entries are valid (n={len(df)}), proceeding with prediction",
164+
icon="✅",
165+
)
145166
column = "SMILES"
146167
smiles_column = df["SMILES"]
147168
multismiles = True
148169

149-
st.markdown("## Select your model parameters")
170+
st.markdown("## Model parameters :nut_and_bolt:")
150171

151172

152173
targets = ASAPMLModelRegistry.get_targets_with_models()
@@ -167,71 +188,129 @@ def convert_df(df):
167188
_target = target_value
168189
_target_str = target_value
169190
# Get the latest model for the target and endpoint
170-
model = ASAPMLModelRegistry.get_latest_model_for_target_type_and_endpoint(_target, "GAT", endpoint_value)
191+
model = ASAPMLModelRegistry.get_latest_model_for_target_type_and_endpoint(
192+
_target, "GAT", endpoint_value
193+
)
171194
if model is None:
172195
st.write(f"No model found for {target_value} {endpoint_value}")
173196
st.stop()
174197
# retry with a different target or endpoint
175198

176-
st.markdown("## Prediction time 🚀")
199+
st.markdown("## Prediction 🚀")
177200

178201

179-
st.write(f"Predicting {_target_str} {endpoint_value} using model {model.name}")
202+
st.write(
203+
f"Predicting **{_target_str} {endpoint_value}** using model:\n\n `{model.name}`"
204+
)
180205
# Create a GATInference object from the model
181206
infr = GATInference.from_ml_model_spec(model)
182207
if infr.is_ensemble:
183-
st.write(f"Ensemble model with {len(model.models)} models, will estimate uncertainty using ensemble variance")
208+
st.write(
209+
f"_Using ensemble model (n={len(model.models)}); estimating uncertainty as variance of predictions._"
210+
)
184211
# Predict the property value for each SMILES string
185-
predictions = [infr.predict_from_smiles(smiles, return_err=True) for smiles in smiles_column]
212+
predictions = [
213+
infr.predict_from_smiles(smiles, return_err=True) for smiles in smiles_column
214+
]
186215
predictions = np.asarray(predictions)
187216
# check if second column is all np.nan
188217
if np.all(np.isnan(predictions[:, 1])):
189218
preds = predictions[:, 0]
190219
err = None
191220
else:
192221
preds = predictions[:, 0]
193-
err = predictions[:, 1] # rejoin with the original dataframe
222+
err = predictions[:, 1] # rejoin with the original dataframe
194223

195224

196-
df["predictions"] = preds
197-
df["prediction_error"] = err
198-
199-
# sort the dataframe by predictions
200-
df = df.sort_values(by="predictions", ascending=False)
225+
pred_column_name = f"{_target_str}_computed-{endpoint_value}"
226+
unc_column_name = f"{_target_str}_computed-{endpoint_value}_uncertainty"
227+
df[pred_column_name] = preds
228+
df[unc_column_name] = err
201229

230+
st.markdown("---")
202231
if multismiles:
203232
# plot the predictions and errors
204-
st.scatter_chart(df, x=column, y="predictions", color="prediction_error", use_container_width=True, x_label="SMILES", y_label=f"Predicted {_target_str} {endpoint_value} ")
233+
# Histogram first
234+
fig, ax = plt.subplots()
235+
236+
sorted_df = df.sort_values(by=pred_column_name)
237+
n_bins = int(len(sorted_df[pred_column_name]) / 10)
238+
if n_bins < 5: # makes the histogram slightly more interpretable with low data
239+
n_bins = 5
240+
241+
ax.hist(sorted_df[pred_column_name], bins=n_bins)
242+
243+
ax.set_ylabel("Count")
244+
ax.set_xlabel(f"Computed {endpoint_value}")
245+
ax.set_title(f"Histogram of computed {endpoint_value} for target: {_target_str}")
246+
247+
st.pyplot(fig)
248+
249+
# then a barplot
250+
fig, ax = plt.subplots()
251+
252+
ax.bar(range(len(sorted_df)), sorted_df[pred_column_name])
253+
254+
ax.set_xticks([])
255+
ax.set_xlabel(f"Query compounds")
256+
ax.set_ylabel(f"Computed {endpoint_value}")
257+
258+
ax.set_title(f"Barplot of computed {endpoint_value} for target: {_target_str}")
259+
260+
st.pyplot(fig)
261+
262+
if endpoint_value == "pIC50":
263+
from rdkit.Chem.Descriptors import MolWt
264+
import seaborn as sns
265+
266+
# then a scatterplot of uncertainty vs MW
267+
df["MW"] = [MolWt(Chem.MolFromSmiles(smi)) for smi in sorted_df["SMILES"]]
268+
fig, ax = plt.subplots()
269+
270+
ax = sns.scatterplot(
271+
x="MW", y=pred_column_name, hue=unc_column_name, palette="coolwarm", data=df
272+
)
273+
274+
norm = plt.Normalize(df[unc_column_name].min(), df[unc_column_name].max())
275+
sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
276+
sm.set_array([])
277+
278+
# Remove the legend and add a colorbar
279+
ax.get_legend().remove()
280+
cbar = ax.figure.colorbar(sm, ax=ax)
281+
ax.annotate(
282+
f"Computed {endpoint_value} uncertainty",
283+
xy=(1.2, 0.3),
284+
xycoords="axes fraction",
285+
rotation=270,
286+
)
287+
288+
ax.set_title(
289+
f"Scatterplot of predicted {endpoint_value} versus MW\ntarget: {_target_str}"
290+
)
291+
ax.set_xlabel(f"Molecular weight (Da)")
292+
ax.set_ylabel(f"Computed {endpoint_value}")
293+
st.pyplot(fig)
205294

206295
else:
207296
# just print the prediction
208-
preds = df["predictions"].values[0]
297+
preds = df[pred_column_name].values[0]
209298
smiles = df["SMILES"].values[0]
210299
if err:
211-
err = df["prediction_error"].values[0]
300+
err = df[unc_column_name].values[0]
212301
errstr = f"± {err:.2f}"
213302
else:
214303
errstr = ""
215-
216-
st.markdown("### 🕵️")
217-
st.markdown(f"Predicted {_target_str} {endpoint_value} for {smiles} is {preds:.2f} {errstr} using model {infr.model_name}")
304+
305+
st.markdown(
306+
f"Predicted {_target_str} {endpoint_value} for {smiles} is {preds:.2f} {errstr}."
307+
)
218308

219309
# allow the user to download the predictions
220310
csv = convert_df(df)
221311
st.download_button(
222-
label="Download data as CSV",
223-
data=csv,
224-
file_name=f"predictions_{model.name}.csv",
225-
mime="text/csv",
226-
)
227-
228-
229-
230-
231-
232-
233-
234-
235-
236-
# else:
237-
# st.write("Please upload a CSV file to view its contents.")
312+
label="Download data as CSV",
313+
data=csv,
314+
file_name=f"predictions_{model.name}.csv",
315+
mime="text/csv",
316+
)

0 commit comments

Comments
 (0)