Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

error:KeyError: 'xyz' #208

Open
exYuan opened this issue Jan 31, 2024 · 1 comment
Open

error:KeyError: 'xyz' #208

exYuan opened this issue Jan 31, 2024 · 1 comment

Comments

@exYuan
Copy link

exYuan commented Jan 31, 2024

Dear espaloma devs,

I have been researching learning graph neural networks recently and saw that your project ESPALOMA was replicated using your PHALKETHOH database. However, when I tried to replicate your ZINC dataset (which contains SDF files), I used a training script adapted from collection.py to read and convert SDFs. However, when adding the graph, I encountered an error stating that the dictionary xyz could not be found. I also encountered the same problem when using the PDBBIND dataset, After printing the dictionary for the key, it was found that there was no xyz. The image on the left shows phalkethoh imported as binary, while the image on the right shows the imported sdf of ligands in the pdbbind dataset. There is a lot of missing information on the right side of the graph
31f7ef01e6d6fd2a3ec5d05ef2412e1
e906ce7ed2c769455feceecbfaf55bb

I don't know how to properly write the code for importing from sdf files to espgraph. I can import all the parameters of the molecules. I tried the method in collection.py in the project, but it doesn't work. Can you please advise me on how to write it?
Best wishes

`import sys
from os.path import exists
import os
from openff.toolkit.topology import Molecule
from rdkit import Chem
from rdkit.Chem import AllChem
import espaloma as esp
import torch
import csv
firs=-1
count = 0
fi = 0
gs = []

定义解压目标文件夹

extract_folder = r'/home/twyuan/test/pdb/refined_set'

for root, dirs, files in os.walk(extract_folder):
for file in files:
# 确保文件是.sdf文件
if file.endswith('.sdf'):
sdf_file_path = os.path.join(root, file)

        try:
            # 使用rdkit读取.sdf文件
            _mols = Chem.ForwardSDMolSupplier(sdf_file_path, removeHs=False)
            
            # 遍历.sdf文件中的分子
            for mol in _mols:
                try:
                    
                    # 尝试修复分子结构
                    mol = Chem.Mol(mol)
                    #AllChem.Compute2DCoords(mol)  # 计算2D坐标
                    Chem.Kekulize(mol, clearAromaticFlags=True)  # 进行芳香性规范化
                    mol = Chem.AddHs(mol)  # 添加氢原子
                    Chem.SanitizeMol(mol)  # 检查和修复分子结构的一般问题 
                    
                    
                    #print(f'Molecule XYZ coordinates: {Chem.MolToMolBlock(mol)}')
                    gs.append(
                        esp.Graph(
                            Molecule.from_rdkit(mol, allow_undefined_stereo=True)
                        )
                    )
                    #print (esp.graphs.graph.__dict__)
                    #sys.exit()
                    '''
                    smi = Chem.MolToSmiles(mol)
                    gs.append(
                        esp.Graph(
                            Molecule.from_smiles(smi, allow_undefined_stereo=True)
                        )
                    )
                    '''
                    '''
                    esp_graph = esp.Graph(Molecule.from_rdkit(mol, allow_undefined_stereo=True))
                    if 'xyz' not in esp_graph.nodes[0].data:
                        print(f"Warning: 'xyz' not found in esp_graph.nodes[0].data for molecule {mol.GetProp('_Name')}")
                    else:
                        gs.append(esp_graph)

                    count += 1
                    '''

                except Exception as e:
                    # 输出异常信息
                    print(f"Error processing molecule in file {sdf_file_path}: {e}")
                    # 跳过当前分子,继续处理下一个分子
                    continue

                if firs != -1 and count >= firs:
                    break
                # 在这里进行你的处理,比如输出分子信息等

                if mol is not None:
                    fi += 1
                    print(f'Molecule name: {mol.GetProp("_Name")}')
                    # 这里可以继续处理其他分子信息
            
            # 在这里添加打印语句
            print("Number of esp.Graph objects in gs:", len(gs))

        except Exception as e_file:
            # 输出异常信息
            print(f"Error processing file {sdf_file_path}: {e_file}")
            # 跳过当前文件,继续处理下一个文件
            continue

print(fi)
print(len(gs))

ds = esp.data.dataset.GraphDataset(gs)
#print (g.dict)
ds.shuffle(seed=2666)
ds_tr, ds_vl, ds_te = ds.split([8, 1, 1])

print(len(ds_tr)) # 打印数据集中的样本数量
print(len(ds_vl))
print(len(ds_te))
ds_tr_loader = ds_tr.view(batch_size=100, shuffle=True)
g_tr = next(iter(ds_tr.view(batch_size=len(ds_tr))))
g_vl = next(iter(ds_vl.view(batch_size=len(ds_vl))))
representation = esp.nn.Sequential(
layer=esp.nn.layers.dgl_legacy.gn("SAGEConv"), # use SAGEConv implementation in DGL
config=[128, "relu", 128, "relu", 128, "relu"], # 3 layers, 128 units, ReLU activation
)
readout = esp.nn.readout.janossy.JanossyPooling(
in_features=128, config=[128, "relu", 128, "relu", 128, "relu"],
out_features={ # define modular MM parameters Espaloma will assign
1: {"e": 1, "s": 1}, # atom hardness and electronegativity
2: {"log_coefficients": 2}, # bond linear combination, enforce positive
3: {"log_coefficients": 2}, # angle linear combination, enforce positive
4: {"k": 6}, # torsion barrier heights (can be positive or negative)
},
)

espaloma_model = torch.nn.Sequential(
representation, readout, esp.nn.readout.janossy.ExpCoefficients(),
esp.mm.geometry.GeometryInGraph(),
esp.mm.energy.EnergyInGraph(),
esp.mm.energy.EnergyInGraph(suffix="_ref"),
esp.nn.readout.charge_equilibrium.ChargeEquilibrium(),
)
if torch.cuda.is_available():
espaloma_model = espaloma_model.cuda()

loss_fn = esp.metrics.GraphMetric(
base_metric=torch.nn.MSELoss(), # use mean-squared error loss
between=['u', "u_ref"], # between predicted and QM energies
level="g", # compare on graph level
)

optimizer = torch.optim.Adam(espaloma_model.parameters(), 1e-4)

for idx_epoch in range(10):
for g in ds_tr_loader:
optimizer.zero_grad()
if torch.cuda.is_available():
g = g.to("cuda:0")
g = espaloma_model(g)
loss = loss_fn(g)
loss.backward()
optimizer.step()
torch.save(espaloma_model.state_dict(), "%s.th" % idx_epoch)
# 保存模型参数
torch.save(espaloma_model.state_dict(), f"/home/twyuan/test/4/{idx_epoch}.th")

    # 输出学习率和损失
    current_lr = optimizer.param_groups[0]["lr"]
    current_loss = loss.item()
    print(f"Epoch {idx_epoch + 1}, Learning Rate: {current_lr}, Loss: {current_loss}")

    # 保存学习率和损失
    learning_rates.append(current_lr)
    losses.append(current_loss)

画学习率和损失曲线图

plt.plot(learning_rates, label="Learning Rate")
plt.plot(losses, label="Loss")
plt.xlabel("Epoch")
plt.legend()
plt.show()
plt.savefig('/home/twyuan/test/4/learning_curve.png')

inspect_metric = esp.metrics.GraphMetric(
base_metric=torch.nn.L1Loss(), # use mean-squared error loss
between=['u', "u_ref"], # between predicted and QM energies
level="g", # compare on graph level
)

if torch.cuda.is_available():
g_vl = g_vl.to("cuda:0")
g_tr = g_tr.to("cuda:0")

loss_tr = []
loss_vl = []

for idx_epoch in range(10):
espaloma_model.load_state_dict(
torch.load("%s.th" % idx_epoch)
)

espaloma_model(g_tr)
loss_tr.append(inspect_metric(g_tr).item())

espaloma_model(g_vl)
loss_vl.append(inspect_metric(g_vl).item())

import numpy as np
loss_tr = np.array(loss_tr) * 627.5
loss_vl = np.array(loss_vl) * 627.5

from matplotlib import pyplot as plt
plt.plot(loss_tr, label="train")
plt.plot(loss_vl, label="valid")
plt.yscale("log")
plt.legend()
plt.show()
plt.savefig('/home/twyuan/test/4/train_valid.png')

import torch
from openff.toolkit.topology import Molecule
from rdkit import Chem
import espaloma as esp
from espaloma.data.collection import zinc
import dgl
import numpy as np

def main():
# 假设你的数据文件在当前工作目录下,文件名为 data.zinc
ds = zinc(first=10) # 仅加载前10个分子
for item in ds:
print(item)
# 处理返回的 GraphDataset 对象
if isinstance(ds, esp.data.dataset.GraphDataset):
print("GraphDataset loaded successfully!")
# 这里可以根据需要使用 result 进行进一步的操作

    typing = esp.graphs.legacy_force_field.LegacyForceField('gaff-1.81')
    ds.apply(typing, in_place=True) # this modify the original data

    ds_tr, ds_vl, ds_te = ds.split([8, 1, 1])
    ds_tr = ds_tr.view('graph', batch_size=100, shuffle=True)
    ds_te = ds_te.view('graph', batch_size=100)
    ds_vl = ds_vl.view('graph', batch_size=100)
    # define a layer
    layer = esp.nn.layers.dgl_legacy.gn("SAGEConv")

    # define a representation
    representation = esp.nn.Sequential(
            layer,
            [128, "relu", 128, "relu", 128, "relu"],
    )

    # define a readout
    readout = esp.nn.readout.node_typing.NodeTyping(
            in_features=128,
            n_classes=100
    )

    net = torch.nn.Sequential(
        representation,
        readout
    )

    loss_fn = esp.metrics.TypingAccuracy()

    # define optimizer
    optimizer = torch.optim.Adam(net.parameters(), 1e-5)

    # train the model
    for _ in range(3000):
        for g in ds_tr:
            optimizer.zero_grad()
            net(g.heterograph)
            loss = loss_fn(g.heterograph)
            loss.backward()
            optimizer.step()

if name == "main":
main()
`

@mikemhenry
Copy link
Contributor

I don't know if this helps but someone else ran into this issue and made a PR #203

The changes didn't get merged in but maybe you would find the code useful

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants