Skip to content

Commit

Permalink
Update and test models.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Dec 11, 2023
1 parent e223723 commit 82fde0f
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 22 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ You can find a [table](training/results/README.md) of common benchmark datasets

Some known issues to be aware of, if using and making new models or layers with `kgcnn`.
* Jagged or nested Tensors loading into models for PyTorch backend is not working.
* Dataloader for Jax is not yet implemented.

<a name="citing"></a>
# Citing
Expand Down
66 changes: 61 additions & 5 deletions kgcnn/graph/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,13 @@ def call(self, *, node_coordinates: np.ndarray, graph_lattice: np.ndarray):


class CountNodesAndEdges(GraphPreProcessorBase):
r""".
r"""Count the number of nodes and edges.
Args:
node_attributes (str):
edge_indices (str):
count_node (str):
count_edge (str):
node_attributes (str): Name for nodes attributes to count nodes.
edge_indices (str): Name for edge_indices to count edges.
count_node (str): Name to assign node count to.
count_edge (str): Name to assign edge count to.
"""

def __init__(self, *, total_nodes: str = "total_nodes", total_edges: str = "total_edges",
Expand All @@ -578,3 +578,59 @@ def call(self, *, count_nodes: np.ndarray, count_edges: np.ndarray):
total_nodes = len(count_nodes) if count_nodes is not None else None
total_edges = len(count_edges) if count_edges is not None else None
return total_nodes, total_edges


class MakeDenseAdjacencyMatrix(GraphPreProcessorBase):
r"""Make adjacency matrix based on edge indices.
Args:
edge_indices (str): Name of adjacency matrix.
edge_attributes (str): Name of edge attributes.
adjacency_matrix (str): Name of adjacency matrix to assign output.
"""

def __init__(self, *, edge_indices: str = "edge_indices", edge_attributes: str = "edge_attributes",
adjacency_matrix: str = "adjacency_matrix",
name="make_dense_adjacency_matrix", **kwargs):
super().__init__(name=name, **kwargs)
self._to_obtain.update({"edge_indices": edge_indices, "edge_attributes": edge_attributes})
self._to_assign = adjacency_matrix
self._config_kwargs.update({"adjacency_matrix": adjacency_matrix, "edge_indices": edge_indices,
"edge_attributes": edge_attributes})

def call(self, *, edge_indices: np.ndarray, edge_attributes: np.ndarray):
if edge_indices is None or edge_attributes is None:
return None
if len(edge_indices) == 0:
return None
max_index = np.amax(edge_indices)+1
adj = np.zeros([max_index, max_index] + list(edge_attributes.shape[1:]))
adj[edge_indices[:, 0], edge_indices[:, 1]] = edge_attributes
return adj


class MakeMask(GraphPreProcessorBase):
r"""Make simple equally sized (dummy) mask for a graph property.
Args:
target_property (str): Name of the property to make mask for.
mask_name (str): Name of mask to assign output.
rank (int): Rank of mask.
"""

def __init__(self, *, target_property: str = "node_attributes", mask_name: str = None,
rank=1,
name="make_mask", **kwargs):
super().__init__(name=name, **kwargs)
if mask_name is None:
mask_name = target_property+"_mask"
self._to_obtain.update({"target_property": target_property})
self._to_assign = mask_name
self._call_kwargs = {"rank": rank}
self._config_kwargs.update({"mask_name": mask_name, "target_property": target_property, "rank": rank})

def call(self, *, target_property: np.ndarray, rank: int):
if target_property is None:
return None
mask = np.ones(target_property.shape[:rank], dtype="bool")
return mask
4 changes: 3 additions & 1 deletion kgcnn/graph/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def get_preprocessor(name: Union[str, dict], **kwargs):
"expand_distance_gaussian_basis": "ExpandDistanceGaussianBasis",
"atomic_charge_representation": "AtomicChargesRepresentation",
"principal_moments_of_inertia": "PrincipalMomentsOfInertia",
"count_nodes_and_edges": "CountNodesAndEdges"
"count_nodes_and_edges": "CountNodesAndEdges",
"make_dense_adjacency_matrix": "MakeDenseAdjacencyMatrix",
"make_mask": "MakeMask"
}
if isinstance(name, dict):
return deserialize(name)
Expand Down
6 changes: 6 additions & 0 deletions kgcnn/literature/MAT/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ._make import make_model, model_default

__all__ = [
"make_model",
"model_default"
]
10 changes: 7 additions & 3 deletions kgcnn/literature/MAT/_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def call(self, inputs, mask=None, **kwargs):
diff = ops.expand_dims(inputs, axis=1) - ops.expand_dims(inputs, axis=2)
dist = ops.sum(ops.square(diff), axis=-1, keepdims=True)
# shape of dist (batch, N, N, 1)
mask = ops.cast(mask, dtype="float32")
diff_mask = ops.expand_dims(mask, axis=1) * ops.expand_dims(mask, axis=2)
dist_mask = ops.prod(diff_mask, axis=-1, keepdims=True)

Expand Down Expand Up @@ -101,7 +102,8 @@ def call(self, inputs, **kwargs):
Returns:
Tensor: Product of inputs along axis.
"""
return ops.prod(inputs, keepdims=self.keepdims, axis=self.axis)
out = ops.prod(inputs, keepdims=self.keepdims, axis=self.axis)
return out

def get_config(self):
config = super(MATReduceMask, self).get_config()
Expand All @@ -127,7 +129,8 @@ def call(self, inputs, **kwargs):
Returns:
Tensor: Input tensor with expanded axis.
"""
return ops.expand_dims(inputs, axis=self.axis)
out = ops.expand_dims(inputs, axis=self.axis)
return out

def get_config(self):
config = super(MATExpandMask, self).get_config()
Expand Down Expand Up @@ -166,7 +169,7 @@ def call(self, inputs, mask=None, **kwargs):
r"""Forward pass.
Args:
inputs (list): List of [h_n, A_d, A_g] represented by padded :obj:`tf.Tensor` .
inputs (list): List of [h_n, A_d, A_g] represented by padded :obj:`Tensor` .
These are node features and adjacency matrix from distances and bonds or bond order.
mask (list): Mask tensors matching inputs, i.e. a mask tensor for each padded input.
Expand All @@ -175,6 +178,7 @@ def call(self, inputs, mask=None, **kwargs):
"""
h, a_d, a_g = inputs
h_mask, a_d_mask, a_g_mask = mask
h_mask = ops.cast(h_mask, dtype=h.dtype)
q = ops.expand_dims(self.dense_q(h), axis=2)
k = ops.expand_dims(self.dense_k(h), axis=1)
v = self.dense_v(h) * h_mask
Expand Down
35 changes: 22 additions & 13 deletions kgcnn/literature/MAT/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@
"name": "MAT",
"inputs": [
{"shape": (None,), "name": "node_number", "dtype": "int64"},

{"shape": (None, 3), "name": "node_coordinates", "dtype": "float32"},
{"shape": (None, None, 1), "name": "adjacency_matrix", "dtype": "float32"},
{"shape": (None,), "name": "node_mask", "dtype": "bool"},
{"shape": (None, None), "name": "adjacency_mask", "dtype": "bool"},
],
"input_embedding": {"node": {"input_dim": 95, "output_dim": 64},
"edge": {"input_dim": 95, "output_dim": 64}},
"use_edge_embedding": False,
"input_tensor_type": "padded",
"input_embedding": None,
"input_node_embedding": {"input_dim": 95, "output_dim": 64},
"input_edge_embedding": {"input_dim": 95, "output_dim": 64},
"max_atoms": None,
"distance_matrix_kwargs": {"trafo": "exp"},
"attention_kwargs": {"units": 8, "lambda_attention": 0.3, "lambda_distance": 0.3, "lambda_adjacency": None,
Expand All @@ -43,30 +47,34 @@
"verbose": 10,
"pooling_kwargs": {"pooling_method": "sum"},
"output_embedding": "graph",
"output_to_tensor": True,
"output_to_tensor": None,
"output_mlp": {"use_bias": [True, True, True], "units": [32, 16, 1],
"activation": ["relu", "relu", "linear"]}
"activation": ["relu", "relu", "linear"]},
"output_tensor_type": "padded"
}


@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"])
def make_model(name: str = None,
inputs: list = None,
input_embedding: dict = None, # noqa
input_node_embedding: dict = None,
input_tensor_type: str = None,
input_edge_embedding: dict = None,
distance_matrix_kwargs: dict = None,
attention_kwargs: dict = None,
max_atoms: int = None,
feed_forward_kwargs:dict = None,
embedding_units: int = None,
depth: int = None,
heads: int = None,
merge_heads: str = None,
verbose: int = None,
verbose: int = None, # noqa
pooling_kwargs: dict = None,
output_embedding: str = None,
output_to_tensor: bool = None,
output_mlp: dict = None
output_to_tensor: bool = None, # noqa
output_mlp: dict = None,
output_tensor_type: str = None
):
r"""Make `MAT <https://arxiv.org/pdf/2002.08264.pdf>`__ graph network via functional API.
Default parameters can be found in :obj:`kgcnn.literature.MAT.model_default` .
Expand All @@ -84,7 +92,7 @@ def make_model(name: str = None,
- adjacency_matrix (Tensor): Edge attributes of shape `(batch, N, N, F)` or `(batch, N, N)`
using an embedding layer.
- node_mask (Tensor): Node mask of shape `(batch, N)`
- adjacency_mask (Tensor): Adjacency mask of shape `(batch, N)`
- adjacency_mask (Tensor): Adjacency mask of shape `(batch, N, N)`
Outputs:
Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`.
Expand Down Expand Up @@ -114,6 +122,7 @@ def make_model(name: str = None,
:obj:`keras.models.Model`
"""
assert input_tensor_type in ["padded", "mask", "masked"], "Only padded tensors are valid for this implementation."
assert output_tensor_type in ["padded", "mask", "masked"], "Only padded tensors are valid for this implementation."
# Make input
node_input = ks.layers.Input(**inputs[0])
xyz_input = ks.layers.Input(**inputs[1])
Expand All @@ -135,6 +144,8 @@ def make_model(name: str = None,
n_mask = node_mask
adj_mask = adjacency_mask
xyz = xyz_input
n_mask = MATExpandMask(axis=-1)(n_mask)
adj_mask = MATExpandMask(axis=-1)(adj_mask)

# Cast to dense Tensor with padding for MAT.
# Nodes must have feature dimension.
Expand All @@ -152,11 +163,9 @@ def make_model(name: str = None,
if has_edge_dim:
# Assume that feature-wise attention is not desired for adjacency, reduce to single value.
adj = ks.layers.Dense(1, use_bias=False)(adj)
adj_mask = MATReduceMask(axis=-1, keepdims=True)(adj_mask)
else:
# Make sure that shape is (batch, max_atoms, max_atoms, 1).
adj = MATExpandMask(axis=-1)(adj)
adj_mask = MATExpandMask(axis=-1)(adj_mask)

# Repeat for depth.
h_mask = n_mask
Expand Down Expand Up @@ -202,7 +211,7 @@ def make_model(name: str = None,
raise ValueError("Unsupported graph embedding for mode `MAT` .")

model = ks.models.Model(
inputs=[node_input, xyz_input, adjacency_matrix],
inputs=[node_input, xyz_input, adjacency_matrix, node_mask, adjacency_mask],
outputs=out,
name=name
)
Expand Down
Empty file removed kgcnn/literature/MAT/_model.py
Empty file.
81 changes: 81 additions & 0 deletions training/hyper/hyper_esol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1586,4 +1586,85 @@
"kgcnn_version": "4.0.0"
}
},
"MAT": {
"model": {
"class_name": "make_model",
"module_name": "kgcnn.literature.MAT",
"config": {
"name": "MAT",
"inputs": [
{"shape": (None,), "name": "node_number", "dtype": "int64"},
{"shape": (None, 3), "name": "node_coordinates", "dtype": "float64"},
{"shape": (None, None, 11), "name": "adjacency_matrix", "dtype": "float64"},
{"shape": (None,), "name": "node_mask", "dtype": "bool"},
{"shape": (None, None), "name": "adjacency_mask", "dtype": "bool"},
],
"input_embedding": None,
"input_node_embedding": {"input_dim": 95, "output_dim": 64},
"input_edge_embedding": None,
"distance_matrix_kwargs": {"trafo": "exp"},
"attention_kwargs": {"units": 8, "lambda_attention": 0.3, "lambda_distance": 0.3,
"lambda_adjacency": None, "add_identity": False,
"dropout": 0.1},
"feed_forward_kwargs": {"units": [32, 32, 32], "activation": ["relu", "relu", "linear"]},
"embedding_units": 32,
"depth": 5,
"heads": 8,
"merge_heads": "concat",
"verbose": 10,
"pooling_kwargs": {"pooling_method": "sum"},
"output_embedding": "graph",
"output_to_tensor": True,
"output_mlp": {"use_bias": [True, True, True], "units": [32, 16, 1],
"activation": ["relu", "relu", "linear"]}
}
},
"training": {
"fit": {
"batch_size": 32,
"epochs": 400,
"validation_freq": 10,
"verbose": 2,
"callbacks": [
{
"class_name": "kgcnn>LinearLearningRateScheduler",
"config": {
"learning_rate_start": 5e-04, "learning_rate_stop": 1e-05, "epo_min": 0, "epo": 400,
"verbose": 0
}
}
]
},
"compile": {
"optimizer": {"class_name": "Adam", "config": {"learning_rate": 5e-04}},
"loss": "mean_absolute_error"
},
"cross_validation": {"class_name": "KFold",
"config": {"n_splits": 5, "random_state": 42, "shuffle": True}},
"scaler": {"class_name": "StandardLabelScaler",
"config": {"with_std": True, "with_mean": True, "copy": True}}
},
"data": {
"dataset": {
"class_name": "ESOLDataset",
"module_name": "kgcnn.data.datasets.ESOLDataset",
"config": {},
"methods": [
{"set_attributes": {}},
{"map_list": {"method": "set_edge_weights_uniform"}},
{"map_list": {"method": "make_dense_adjacency_matrix"}},
{"map_list": {"method": "make_mask", "target_property": "node_number",
"mask_name": "node_mask", "rank": 1}},
{"map_list": {"method": "make_mask", "target_property": "adjacency_matrix",
"mask_name": "adjacency_mask", "rank": 2}}
]
},
"data_unit": "mol/L"
},
"info": {
"postfix": "",
"postfix_file": "",
"kgcnn_version": "4.0.0"
}
},
}

0 comments on commit 82fde0f

Please sign in to comment.