Skip to content

Commit

Permalink
Merge pull request #3 from asapdiscovery/fix_smiles_issue
Browse files (browse the repository at this point in the history)
fix SMILES parser downstream
  • Loading branch information
hmacdope authored Sep 26, 2024
2 parents 44b4800 + f0817fb commit 3583f87
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions app.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -78,9 +78,9 @@ def convert_df(df):
st.error("Invalid molecule", icon="🚨")
st.stop()
smiles = [smiles]
df = pd.DataFrame(smiles, columns=["SMILES"])
column = "SMILES"
smiles_column = df["SMILES"]
queried_df = pd.DataFrame(smiles, columns=["SMILES"])
smiles_column_name = "SMILES"
smiles_column = queried_df[smiles_column_name]
elif input == "Enter SMILES":
smiles = st.text_input("Enter a SMILES string")
if _is_valid_smiles(smiles):
Expand All @@ -89,9 +89,9 @@ def convert_df(df):
st.error("Invalid SMILES string", icon="🚨")
st.stop()
smiles = [smiles]
df = pd.DataFrame(smiles, columns=["SMILES"])
column = "SMILES"
smiles_column = df["SMILES"]
queried_df = pd.DataFrame(smiles, columns=["SMILES"])
smiles_column_name = "SMILES"
smiles_column = queried_df[smiles_column_name]
elif input == "Upload a CSV file":
# Create a file uploader for CSV files
uploaded_file = st.file_uploader(
Expand All @@ -100,13 +100,13 @@ def convert_df(df):

# If a file is uploaded, parse it into a DataFrame
if uploaded_file is not None:
df = pd.read_csv(uploaded_file)
queried_df = pd.read_csv(uploaded_file)
else:
st.stop()
# Select a column from the DataFrame
column = st.selectbox("Select a column of SMILES analyze", df.columns)
smiles_column_name = st.selectbox("Select a SMILES column", queried_df.columns)
multismiles = True
smiles_column = df[column]
smiles_column = queried_df[smiles_column_name]

# check if the smiles are valid
valid_smiles = [_is_valid_smiles(smi) for smi in smiles_column]
Expand All @@ -133,18 +133,18 @@ def convert_df(df):
string_data = stringio.read()
mols = sdf_str_to_rdkit_mol(string_data)
smiles = [Chem.MolToSmiles(m) for m in mols]
df = pd.DataFrame(smiles, columns=["SMILES"])
queried_df = pd.DataFrame(smiles, columns=["SMILES"])
# st.error("Error reading the SDF file, please check the input", icon="🚨")
# st.stop()
else:
st.stop()

st.success(
f"All molecule entries are valid (n={len(df)}), proceeding with prediction",
f"All molecule entries are valid (n={len(queried_df)}), proceeding with prediction",
icon="✅",
)
column = "SMILES"
smiles_column = df["SMILES"]
smiles_column_name = "SMILES"
smiles_column = queried_df[smiles_column_name]
multismiles = True

st.markdown("## Model parameters :nut_and_bolt:")
Expand Down Expand Up @@ -204,16 +204,16 @@ def convert_df(df):

pred_column_name = f"{_target_str}_computed-{endpoint_value}"
unc_column_name = f"{_target_str}_computed-{endpoint_value}_uncertainty"
df[pred_column_name] = preds
df[unc_column_name] = err
queried_df[pred_column_name] = preds
queried_df[unc_column_name] = err

st.markdown("---")
if multismiles:
# plot the predictions and errors
# Histogram first
fig, ax = plt.subplots()

sorted_df = df.sort_values(by=pred_column_name)
sorted_df = queried_df.sort_values(by=pred_column_name)
n_bins = int(len(sorted_df[pred_column_name]) / 10)
if n_bins < 5: # makes the histogram slightly more interpretable with low data
n_bins = 5
Expand Down Expand Up @@ -244,14 +244,22 @@ def convert_df(df):
import seaborn as sns

# then a scatterplot of uncertainty vs MW
df["MW"] = [MolWt(Chem.MolFromSmiles(smi)) for smi in sorted_df["SMILES"]]
queried_df["MW"] = [
MolWt(Chem.MolFromSmiles(smi)) for smi in sorted_df[smiles_column_name]
]
fig, ax = plt.subplots()

ax = sns.scatterplot(
x="MW", y=pred_column_name, hue=unc_column_name, palette="coolwarm", data=df
x="MW",
y=pred_column_name,
hue=unc_column_name,
palette="coolwarm",
data=queried_df,
)

norm = plt.Normalize(df[unc_column_name].min(), df[unc_column_name].max())
norm = plt.Normalize(
queried_df[unc_column_name].min(), queried_df[unc_column_name].max()
)
sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
sm.set_array([])

Expand All @@ -274,10 +282,10 @@ def convert_df(df):

else:
# just print the prediction
preds = df[pred_column_name].values[0]
smiles = df["SMILES"].values[0]
preds = queried_df[pred_column_name].values[0]
smiles = queried_df["SMILES"].values[0]
if err:
err = df[unc_column_name].values[0]
err = queried_df[unc_column_name].values[0]
errstr = f"± {err:.2f}"
else:
errstr = ""
Expand All @@ -287,7 +295,7 @@ def convert_df(df):
)

# allow the user to download the predictions
csv = convert_df(df)
csv = convert_df(queried_df)
st.download_button(
label="Download data as CSV",
data=csv,
Expand Down

0 comments on commit 3583f87

Please sign in to comment.