-
Notifications
You must be signed in to change notification settings - Fork 31
/
preprocessor.py
164 lines (149 loc) · 8.52 KB
/
preprocessor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import numpy as np
from kgcnn.molecule.graph_rdkit import MolecularGraphRDKit
from kgcnn.graph.base import GraphPreProcessorBase
from kgcnn.molecule.methods import inverse_global_proton_dict
from kgcnn.molecule.io import parse_list_to_xyz_str
from kgcnn.molecule.encoder import OneHotEncoder
from kgcnn.utils.serial import serialize
from kgcnn.molecule.serial import deserialize_encoder
# Molecule-graph backend class used by the preprocessors in this module. Each preprocessor call
# instantiates a fresh object via `_mol_graph_interface()` to build the molecule from graph arrays.
_mol_graph_interface = MolecularGraphRDKit
class SetMolBondIndices(GraphPreProcessorBase):
    r"""Preprocessor to compute chemical bonds from coordinates via a :obj:`MolGraphInterface` .

    The atomic symbols and coordinates are written to an xyz-string, which is parsed by the module-level
    molecule backend. The bond indices and bond orders it determines are assigned to the graph.

    Args:
        node_coordinates (str): Name of atomic coordinates array of shape `(N, 3)` .
        node_symbol (str): Name of atomic symbol as numpy array of shape `(N, )` . If the graph has no
            symbols, they are derived from the atomic numbers in `node_number` .
        node_number (str): Name of atomic numbers array of shape `(N, )` .
        edge_indices (str): Name to assign edge indices to.
        edge_number (str): Name to assign the edge number/order to.
        name (str): Name of this preprocessor.
    """

    def __init__(self, *, node_coordinates: str = "node_coordinates", node_symbol: str = "node_symbol",
                 node_number: str = "node_number",
                 edge_indices: str = "edge_indices", edge_number: str = "edge_number",
                 name="set_mol_bond_indices", **kwargs):
        super().__init__(name=name, **kwargs)
        self._to_obtain.update({"node_coordinates": node_coordinates, "node_number": node_number,
                                "node_symbol": node_symbol})
        self._to_assign = [edge_indices, edge_number]
        self._config_kwargs.update({
            "edge_indices": edge_indices, "node_coordinates": node_coordinates, "node_number": node_number,
            "node_symbol": node_symbol, "edge_number": edge_number})

    def call(self, node_coordinates: np.ndarray, node_symbol: np.ndarray, node_number: np.ndarray):
        """Determine bonds of the molecule given by symbols/numbers and coordinates.

        Args:
            node_coordinates (np.ndarray): Atomic coordinates, anything convertible by :obj:`np.asarray` .
            node_symbol (np.ndarray): Atomic symbols or `None` to fall back to `node_number` .
            node_number (np.ndarray): Atomic numbers, only used if `node_symbol` is `None` .

        Returns:
            tuple: `(edge_indices, edge_number)` from the backend, or `(None, None)` if no molecule
            could be built from the xyz-string.
        """
        if node_symbol is None:
            # `inverse_global_proton_dict` is a lookup table (atomic number -> symbol): index it, do not call it.
            node_symbol = [inverse_global_proton_dict[int(x)] for x in node_number]
        else:
            node_symbol = [str(x) for x in node_symbol]
        # `np.asarray` also accepts plain nested lists of coordinates, not only numpy arrays.
        xyz_str = parse_list_to_xyz_str([node_symbol, np.asarray(node_coordinates).tolist()], number_coordinates=3)
        mol = _mol_graph_interface()
        mol = mol.from_xyz(xyz_str)
        if mol.mol is None:
            return None, None
        idx, edge_num = mol.edge_number
        return idx, edge_num
class SetMolAttributes(GraphPreProcessorBase):
    """Preprocessor to compute molecular attributes from graph arrays that make a valid molecule
    via a :obj:`MolGraphInterface` .

    See :obj:`MoleculeNetDataset` which uses callbacks but has identical nomenclature.

    .. code-block:: python

        from kgcnn.data.datasets.QM7Dataset import QM7Dataset
        from kgcnn.molecule.preprocessor import SetMolAttributes
        ds = QM7Dataset()
        pp = SetMolAttributes()
        print(pp(ds[0]))

    Args:
        nodes (list): List of atomic properties for attributes.
        edges (list): List of bond properties for attributes.
        graph (list): List of molecular properties for attributes.
        encoder_nodes (dict): Dictionary of node attribute encoders.
        encoder_edges (dict): Dictionary of edge attribute encoders.
        encoder_graph (dict): Dictionary of graph attribute encoders.
        node_coordinates (str): Name of numpy array storing atomic coordinates.
        node_symbol (str): Name of numpy array storing atomic symbol. If the graph has no symbols, they are
            derived from the atomic numbers in `node_number` .
        node_number (str): Name of numpy array storing atomic number.
        edge_indices (str): Name of numpy array storing atomic bond indices.
        edge_number (str): Name of numpy array storing atomic bond order.
        node_attributes (str): Name to assign node attributes to.
        edge_attributes (str): Name to assign edge attributes to.
        graph_attributes (str): Name to assign graph attributes to.
        name (str): Name of the preprocessor.
    """

    _default_node_attributes = [
        'Symbol', 'TotalDegree', 'FormalCharge', 'NumRadicalElectrons', 'Hybridization',
        'IsAromatic', 'IsInRing', 'TotalNumHs', 'CIPCode', "ChiralityPossible", "ChiralTag"
    ]
    _default_node_encoders = {
        'Symbol': OneHotEncoder(
            ['B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
            dtype="str"
        ),
        'Hybridization': OneHotEncoder([2, 3, 4, 5, 6]),
        'TotalDegree': OneHotEncoder([0, 1, 2, 3, 4, 5], add_unknown=False),
        'TotalNumHs': OneHotEncoder([0, 1, 2, 3, 4], add_unknown=False),
        'CIPCode': OneHotEncoder(['R', 'S'], add_unknown=False, dtype='str'),
        "ChiralityPossible": OneHotEncoder(["1"], add_unknown=False, dtype='str'),
    }
    _default_edge_attributes = ['BondType', 'IsAromatic', 'IsConjugated', 'IsInRing', 'Stereo']
    _default_edge_encoders = {
        'BondType': OneHotEncoder([1, 2, 3, 12], add_unknown=False),
        'Stereo': OneHotEncoder([0, 1, 2, 3], add_unknown=False)
    }
    _default_graph_attributes = ['ExactMolWt', 'NumAtoms']
    _default_graph_encoders = {}

    def __init__(self, *,
                 nodes: list = None, edges: list = None, graph: list = None,
                 encoder_nodes: dict = None,
                 encoder_edges: dict = None,
                 encoder_graph: dict = None,
                 node_coordinates: str = "node_coordinates", node_symbol: str = "node_symbol",
                 node_number: str = "node_number",
                 edge_indices: str = "edge_indices", edge_number: str = "edge_number",
                 node_attributes: str = "node_attributes", edge_attributes: str = "edge_attributes",
                 graph_attributes: str = "graph_attributes",
                 name="set_mol_attributes", **kwargs):
        super().__init__(name=name, **kwargs)
        # Copy the attribute lists so the class-level defaults (or the caller's lists) are never shared with,
        # and thereby mutated through, the stored call/config kwargs.
        nodes = list(nodes) if nodes is not None else list(self._default_node_attributes)
        edges = list(edges) if edges is not None else list(self._default_edge_attributes)
        graph = list(graph) if graph is not None else list(self._default_graph_attributes)
        encoder_nodes = encoder_nodes if encoder_nodes is not None else self._default_node_encoders
        encoder_edges = encoder_edges if encoder_edges is not None else self._default_edge_encoders
        encoder_graph = encoder_graph if encoder_graph is not None else self._default_graph_encoders
        self._to_obtain.update({"node_coordinates": node_coordinates, "node_number": node_number,
                                "node_symbol": node_symbol, "edge_indices": edge_indices,
                                "edge_number": edge_number})
        self._to_assign = [node_attributes, edge_attributes, graph_attributes, edge_indices, edge_number]
        # Encoders are deserialized for use in `call` and serialized for the config.
        self._call_kwargs = {
            "nodes": nodes,
            "edges": edges,
            "graph": graph,
            "encoder_nodes": {key: deserialize_encoder(value) for key, value in encoder_nodes.items()},
            "encoder_edges": {key: deserialize_encoder(value) for key, value in encoder_edges.items()},
            "encoder_graph": {key: deserialize_encoder(value) for key, value in encoder_graph.items()}
        }
        self._config_kwargs.update({
            "edge_indices": edge_indices, "node_coordinates": node_coordinates, "node_number": node_number,
            "node_symbol": node_symbol, "edge_number": edge_number,
            "node_attributes": node_attributes, "edge_attributes": edge_attributes,
            "graph_attributes": graph_attributes,
            "nodes": list(nodes),
            "edges": list(edges),
            "graph": list(graph),
            "encoder_nodes": {key: serialize(value) for key, value in encoder_nodes.items()},
            "encoder_edges": {key: serialize(value) for key, value in encoder_edges.items()},
            "encoder_graph": {key: serialize(value) for key, value in encoder_graph.items()}
        })

    def call(self,
             nodes: list, edges: list, graph: list,
             encoder_nodes: dict,
             encoder_edges: dict,
             encoder_graph: dict,
             node_coordinates: np.ndarray, node_symbol: np.ndarray, node_number: np.ndarray,
             edge_indices: np.ndarray, edge_number: np.ndarray):
        """Build the molecule from graph arrays and compute node, edge and graph attributes.

        Returns:
            tuple: `(node_attributes, edge_attributes, graph_attributes, edge_indices, edge_number)` , or a
            tuple of five `None` if the backend could not build a molecule.
        """
        if node_symbol is None:
            # `inverse_global_proton_dict` is a lookup table (atomic number -> symbol): index it, do not call it.
            node_symbol = [inverse_global_proton_dict[int(x)] for x in node_number]
        else:
            node_symbol = [str(x) for x in node_symbol]
        mol = _mol_graph_interface()
        mol.from_list(node_symbol, edge_indices, edge_number, conformer=node_coordinates)
        if mol.mol is None:
            # Same convention as `SetMolBondIndices`: an invalid molecule yields `None` outputs
            # instead of failing inside the attribute getters.
            return None, None, None, None, None
        n_att = mol.node_attributes(nodes, encoder=encoder_nodes)
        _, e_att = mol.edge_attributes(edges, encoder=encoder_edges)
        g_att = mol.graph_attributes(graph, encoder=encoder_graph)
        idx, en = mol.edge_number
        return n_att, e_att, g_att, idx, en