From 82fde0f2751df35942ed5fcc99413fae41ef0774 Mon Sep 17 00:00:00 2001 From: PatReis Date: Mon, 11 Dec 2023 17:56:32 +0100 Subject: [PATCH] Update and test models. --- README.md | 1 + kgcnn/graph/preprocessor.py | 66 ++++++++++++++++++++++++-- kgcnn/graph/serial.py | 4 +- kgcnn/literature/MAT/__init__.py | 6 +++ kgcnn/literature/MAT/_layers.py | 10 ++-- kgcnn/literature/MAT/_make.py | 35 +++++++++----- kgcnn/literature/MAT/_model.py | 0 training/hyper/hyper_esol.py | 81 ++++++++++++++++++++++++++++++++ 8 files changed, 181 insertions(+), 22 deletions(-) delete mode 100644 kgcnn/literature/MAT/_model.py diff --git a/README.md b/README.md index 21d4be85..1f60c312 100644 --- a/README.md +++ b/README.md @@ -299,6 +299,7 @@ You can find a [table](training/results/README.md) of common benchmark datasets Some known issues to be aware of, if using and making new models or layers with `kgcnn`. * Jagged or nested Tensors loading into models for PyTorch backend is not working. +* Dataloader for Jax is not yet implemented. # Citing diff --git a/kgcnn/graph/preprocessor.py b/kgcnn/graph/preprocessor.py index bd8d9d98..2f172dee 100644 --- a/kgcnn/graph/preprocessor.py +++ b/kgcnn/graph/preprocessor.py @@ -556,13 +556,13 @@ def call(self, *, node_coordinates: np.ndarray, graph_lattice: np.ndarray): class CountNodesAndEdges(GraphPreProcessorBase): - r""". + r"""Count the number of nodes and edges. Args: - node_attributes (str): - edge_indices (str): - count_node (str): - count_edge (str): + count_nodes (str): Name of the node attributes used to count nodes. + count_edges (str): Name of the edge indices used to count edges. + total_nodes (str): Name to assign node count to. + total_edges (str): Name to assign edge count to. 
""" def __init__(self, *, total_nodes: str = "total_nodes", total_edges: str = "total_edges", @@ -578,3 +578,59 @@ def call(self, *, count_nodes: np.ndarray, count_edges: np.ndarray): total_nodes = len(count_nodes) if count_nodes is not None else None total_edges = len(count_edges) if count_edges is not None else None return total_nodes, total_edges + + +class MakeDenseAdjacencyMatrix(GraphPreProcessorBase): + r"""Make adjacency matrix based on edge indices. + + Args: + edge_indices (str): Name of edge indices. + edge_attributes (str): Name of edge attributes. + adjacency_matrix (str): Name of adjacency matrix to assign output. + """ + + def __init__(self, *, edge_indices: str = "edge_indices", edge_attributes: str = "edge_attributes", + adjacency_matrix: str = "adjacency_matrix", + name="make_dense_adjacency_matrix", **kwargs): + super().__init__(name=name, **kwargs) + self._to_obtain.update({"edge_indices": edge_indices, "edge_attributes": edge_attributes}) + self._to_assign = adjacency_matrix + self._config_kwargs.update({"adjacency_matrix": adjacency_matrix, "edge_indices": edge_indices, + "edge_attributes": edge_attributes}) + + def call(self, *, edge_indices: np.ndarray, edge_attributes: np.ndarray): + if edge_indices is None or edge_attributes is None: + return None + if len(edge_indices) == 0: + return None + max_index = np.amax(edge_indices)+1 + adj = np.zeros([max_index, max_index] + list(edge_attributes.shape[1:])) + adj[edge_indices[:, 0], edge_indices[:, 1]] = edge_attributes + return adj + + +class MakeMask(GraphPreProcessorBase): + r"""Make simple equally sized (dummy) mask for a graph property. + + Args: + target_property (str): Name of the property to make mask for. + mask_name (str): Name of mask to assign output. + rank (int): Rank of mask. 
+ """ + + def __init__(self, *, target_property: str = "node_attributes", mask_name: str = None, + rank=1, + name="make_mask", **kwargs): + super().__init__(name=name, **kwargs) + if mask_name is None: + mask_name = target_property+"_mask" + self._to_obtain.update({"target_property": target_property}) + self._to_assign = mask_name + self._call_kwargs = {"rank": rank} + self._config_kwargs.update({"mask_name": mask_name, "target_property": target_property, "rank": rank}) + + def call(self, *, target_property: np.ndarray, rank: int): + if target_property is None: + return None + mask = np.ones(target_property.shape[:rank], dtype="bool") + return mask diff --git a/kgcnn/graph/serial.py b/kgcnn/graph/serial.py index 3c5d66dc..b0d02f8e 100644 --- a/kgcnn/graph/serial.py +++ b/kgcnn/graph/serial.py @@ -29,7 +29,9 @@ def get_preprocessor(name: Union[str, dict], **kwargs): "expand_distance_gaussian_basis": "ExpandDistanceGaussianBasis", "atomic_charge_representation": "AtomicChargesRepresentation", "principal_moments_of_inertia": "PrincipalMomentsOfInertia", - "count_nodes_and_edges": "CountNodesAndEdges" + "count_nodes_and_edges": "CountNodesAndEdges", + "make_dense_adjacency_matrix": "MakeDenseAdjacencyMatrix", + "make_mask": "MakeMask" } if isinstance(name, dict): return deserialize(name) diff --git a/kgcnn/literature/MAT/__init__.py b/kgcnn/literature/MAT/__init__.py index e69de29b..fc290485 100644 --- a/kgcnn/literature/MAT/__init__.py +++ b/kgcnn/literature/MAT/__init__.py @@ -0,0 +1,6 @@ +from ._make import make_model, model_default + +__all__ = [ + "make_model", + "model_default" +] diff --git a/kgcnn/literature/MAT/_layers.py b/kgcnn/literature/MAT/_layers.py index eeb758ad..feee25a0 100644 --- a/kgcnn/literature/MAT/_layers.py +++ b/kgcnn/literature/MAT/_layers.py @@ -61,6 +61,7 @@ def call(self, inputs, mask=None, **kwargs): diff = ops.expand_dims(inputs, axis=1) - ops.expand_dims(inputs, axis=2) dist = ops.sum(ops.square(diff), axis=-1, keepdims=True) # shape 
of dist (batch, N, N, 1) + mask = ops.cast(mask, dtype="float32") diff_mask = ops.expand_dims(mask, axis=1) * ops.expand_dims(mask, axis=2) dist_mask = ops.prod(diff_mask, axis=-1, keepdims=True) @@ -101,7 +102,8 @@ def call(self, inputs, **kwargs): Returns: Tensor: Product of inputs along axis. """ - return ops.prod(inputs, keepdims=self.keepdims, axis=self.axis) + out = ops.prod(inputs, keepdims=self.keepdims, axis=self.axis) + return out def get_config(self): config = super(MATReduceMask, self).get_config() @@ -127,7 +129,8 @@ def call(self, inputs, **kwargs): Returns: Tensor: Input tensor with expanded axis. """ - return ops.expand_dims(inputs, axis=self.axis) + out = ops.expand_dims(inputs, axis=self.axis) + return out def get_config(self): config = super(MATExpandMask, self).get_config() @@ -166,7 +169,7 @@ def call(self, inputs, mask=None, **kwargs): r"""Forward pass. Args: - inputs (list): List of [h_n, A_d, A_g] represented by padded :obj:`tf.Tensor` . + inputs (list): List of [h_n, A_d, A_g] represented by padded :obj:`Tensor` . These are node features and adjacency matrix from distances and bonds or bond order. mask (list): Mask tensors matching inputs, i.e. a mask tensor for each padded input. 
@@ -175,6 +178,7 @@ def call(self, inputs, mask=None, **kwargs): """ h, a_d, a_g = inputs h_mask, a_d_mask, a_g_mask = mask + h_mask = ops.cast(h_mask, dtype=h.dtype) q = ops.expand_dims(self.dense_q(h), axis=2) k = ops.expand_dims(self.dense_k(h), axis=1) v = self.dense_v(h) * h_mask diff --git a/kgcnn/literature/MAT/_make.py b/kgcnn/literature/MAT/_make.py index 002157d6..7591d296 100644 --- a/kgcnn/literature/MAT/_make.py +++ b/kgcnn/literature/MAT/_make.py @@ -26,11 +26,15 @@ "name": "MAT", "inputs": [ {"shape": (None,), "name": "node_number", "dtype": "int64"}, - + {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32"}, + {"shape": (None, None, 1), "name": "adjacency_matrix", "dtype": "float32"}, + {"shape": (None,), "name": "node_mask", "dtype": "bool"}, + {"shape": (None, None), "name": "adjacency_mask", "dtype": "bool"}, ], - "input_embedding": {"node": {"input_dim": 95, "output_dim": 64}, - "edge": {"input_dim": 95, "output_dim": 64}}, - "use_edge_embedding": False, + "input_tensor_type": "padded", + "input_embedding": None, + "input_node_embedding": {"input_dim": 95, "output_dim": 64}, + "input_edge_embedding": {"input_dim": 95, "output_dim": 64}, "max_atoms": None, "distance_matrix_kwargs": {"trafo": "exp"}, "attention_kwargs": {"units": 8, "lambda_attention": 0.3, "lambda_distance": 0.3, "lambda_adjacency": None, @@ -43,30 +47,34 @@ "verbose": 10, "pooling_kwargs": {"pooling_method": "sum"}, "output_embedding": "graph", - "output_to_tensor": True, + "output_to_tensor": None, "output_mlp": {"use_bias": [True, True, True], "units": [32, 16, 1], - "activation": ["relu", "relu", "linear"]} + "activation": ["relu", "relu", "linear"]}, + "output_tensor_type": "padded" } @update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"]) def make_model(name: str = None, inputs: list = None, + input_embedding: dict = None, # noqa input_node_embedding: dict = None, input_tensor_type: str = None, 
input_edge_embedding: dict = None, distance_matrix_kwargs: dict = None, attention_kwargs: dict = None, + max_atoms: int = None, feed_forward_kwargs:dict = None, embedding_units: int = None, depth: int = None, heads: int = None, merge_heads: str = None, - verbose: int = None, + verbose: int = None, # noqa pooling_kwargs: dict = None, output_embedding: str = None, - output_to_tensor: bool = None, - output_mlp: dict = None + output_to_tensor: bool = None, # noqa + output_mlp: dict = None, + output_tensor_type: str = None ): r"""Make `MAT `__ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.MAT.model_default` . @@ -84,7 +92,7 @@ def make_model(name: str = None, - adjacency_matrix (Tensor): Edge attributes of shape `(batch, N, N, F)` or `(batch, N, N)` using an embedding layer. - node_mask (Tensor): Node mask of shape `(batch, N)` - - adjacency_mask (Tensor): Adjacency mask of shape `(batch, N)` + - adjacency_mask (Tensor): Adjacency mask of shape `(batch, N, N)` Outputs: Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`. @@ -114,6 +122,7 @@ def make_model(name: str = None, :obj:`keras.models.Model` """ assert input_tensor_type in ["padded", "mask", "masked"], "Only padded tensors are valid for this implementation." + assert output_tensor_type in ["padded", "mask", "masked"], "Only padded tensors are valid for this implementation." # Make input node_input = ks.layers.Input(**inputs[0]) xyz_input = ks.layers.Input(**inputs[1]) @@ -135,6 +144,8 @@ def make_model(name: str = None, n_mask = node_mask adj_mask = adjacency_mask xyz = xyz_input + n_mask = MATExpandMask(axis=-1)(n_mask) + adj_mask = MATExpandMask(axis=-1)(adj_mask) # Cast to dense Tensor with padding for MAT. # Nodes must have feature dimension. @@ -152,11 +163,9 @@ def make_model(name: str = None, if has_edge_dim: # Assume that feature-wise attention is not desired for adjacency, reduce to single value. 
adj = ks.layers.Dense(1, use_bias=False)(adj) - adj_mask = MATReduceMask(axis=-1, keepdims=True)(adj_mask) else: # Make sure that shape is (batch, max_atoms, max_atoms, 1). adj = MATExpandMask(axis=-1)(adj) - adj_mask = MATExpandMask(axis=-1)(adj_mask) # Repeat for depth. h_mask = n_mask @@ -202,7 +211,7 @@ def make_model(name: str = None, raise ValueError("Unsupported graph embedding for mode `MAT` .") model = ks.models.Model( - inputs=[node_input, xyz_input, adjacency_matrix], + inputs=[node_input, xyz_input, adjacency_matrix, node_mask, adjacency_mask], outputs=out, name=name ) diff --git a/kgcnn/literature/MAT/_model.py b/kgcnn/literature/MAT/_model.py deleted file mode 100644 index e69de29b..00000000 diff --git a/training/hyper/hyper_esol.py b/training/hyper/hyper_esol.py index 5839812d..8797172c 100644 --- a/training/hyper/hyper_esol.py +++ b/training/hyper/hyper_esol.py @@ -1586,4 +1586,85 @@ "kgcnn_version": "4.0.0" } }, + "MAT": { + "model": { + "class_name": "make_model", + "module_name": "kgcnn.literature.MAT", + "config": { + "name": "MAT", + "inputs": [ + {"shape": (None,), "name": "node_number", "dtype": "int64"}, + {"shape": (None, 3), "name": "node_coordinates", "dtype": "float64"}, + {"shape": (None, None, 11), "name": "adjacency_matrix", "dtype": "float64"}, + {"shape": (None,), "name": "node_mask", "dtype": "bool"}, + {"shape": (None, None), "name": "adjacency_mask", "dtype": "bool"}, + ], + "input_embedding": None, + "input_node_embedding": {"input_dim": 95, "output_dim": 64}, + "input_edge_embedding": None, + "distance_matrix_kwargs": {"trafo": "exp"}, + "attention_kwargs": {"units": 8, "lambda_attention": 0.3, "lambda_distance": 0.3, + "lambda_adjacency": None, "add_identity": False, + "dropout": 0.1}, + "feed_forward_kwargs": {"units": [32, 32, 32], "activation": ["relu", "relu", "linear"]}, + "embedding_units": 32, + "depth": 5, + "heads": 8, + "merge_heads": "concat", + "verbose": 10, + "pooling_kwargs": {"pooling_method": "sum"}, + 
"output_embedding": "graph", + "output_to_tensor": True, + "output_mlp": {"use_bias": [True, True, True], "units": [32, 16, 1], + "activation": ["relu", "relu", "linear"]} + } + }, + "training": { + "fit": { + "batch_size": 32, + "epochs": 400, + "validation_freq": 10, + "verbose": 2, + "callbacks": [ + { + "class_name": "kgcnn>LinearLearningRateScheduler", + "config": { + "learning_rate_start": 5e-04, "learning_rate_stop": 1e-05, "epo_min": 0, "epo": 400, + "verbose": 0 + } + } + ] + }, + "compile": { + "optimizer": {"class_name": "Adam", "config": {"learning_rate": 5e-04}}, + "loss": "mean_absolute_error" + }, + "cross_validation": {"class_name": "KFold", + "config": {"n_splits": 5, "random_state": 42, "shuffle": True}}, + "scaler": {"class_name": "StandardLabelScaler", + "config": {"with_std": True, "with_mean": True, "copy": True}} + }, + "data": { + "dataset": { + "class_name": "ESOLDataset", + "module_name": "kgcnn.data.datasets.ESOLDataset", + "config": {}, + "methods": [ + {"set_attributes": {}}, + {"map_list": {"method": "set_edge_weights_uniform"}}, + {"map_list": {"method": "make_dense_adjacency_matrix"}}, + {"map_list": {"method": "make_mask", "target_property": "node_number", + "mask_name": "node_mask", "rank": 1}}, + {"map_list": {"method": "make_mask", "target_property": "adjacency_matrix", + "mask_name": "adjacency_mask", "rank": 2}} + ] + }, + "data_unit": "mol/L" + }, + "info": { + "postfix": "", + "postfix_file": "", + "kgcnn_version": "4.0.0" + } + }, }