From 0a27b26fa782a09baada713d931940a6e8566244 Mon Sep 17 00:00:00 2001 From: PatReis Date: Wed, 15 Nov 2023 16:52:54 +0100 Subject: [PATCH] update for keras 3.0 --- changelog.md | 12 ++- kgcnn/literature/AttentiveFP/_make.py | 26 ++--- kgcnn/literature/DMPNN/_make.py | 61 ++++++------ kgcnn/literature/GAT/_make.py | 34 ++++--- kgcnn/literature/GATv2/_make.py | 47 ++++----- kgcnn/literature/GCN/_make.py | 55 +++++------ kgcnn/literature/GIN/_make.py | 100 ++++++++++--------- kgcnn/literature/GraphSAGE/_make.py | 55 +++++------ kgcnn/literature/PAiNN/_make.py | 137 +++++++++++++++----------- kgcnn/literature/Schnet/_make.py | 110 ++++++++++----------- training/hyper/hyper_esol.py | 62 ++++++++++++ training/train_graph.py | 5 +- 12 files changed, 396 insertions(+), 308 deletions(-) diff --git a/changelog.md b/changelog.md index 9feeb268..87998651 100644 --- a/changelog.md +++ b/changelog.md @@ -1,9 +1,19 @@ v4.0.0 -* Reworked training scripts to have a single ``train_graph.py`` script. Command line arguments are now optional and just used for verification, but `category` to select a model/hyperparameter combination from hyper file. +Completely reworked version of kgcnn for Keras 3.0 and multi-backend support. A lot of fundamental changes have been made. +However, we tried to keep as much of the API from kgcnn 3.0 so that models in literature can be used with minimal changes. +Mainly, the ``"input_tensor_type"="ragged"`` model parameter has to be added if ragged tensors are used as input in tensorflow. +The scope of models has been reduced for initial release but will be extended in upcoming versions. + +* Reworked training scripts to have a single ``train_graph.py`` script. Command line arguments are now optional and just used for verification, all but `category` which has to select a model/hyperparameter combination from hyper file. Since the hyperparameter file already contains all necessary information. * Train test indices can now also be set and loaded from the dataset directly. * Scaler behaviour has changed with regard to `transform_dataset`. Key names of properties to transform has been moved to the constructor! +Also be sure to check ``StandardLabelScaler`` if you want to scale regression targets, since target properties are default here. +* Literature Models have an optional output scaler from new `kgcnn.layers.scale` layer controlled by `output_scaling` model argument. +* Input embedding in literature models is now controlled with separate ``input_node_embedding`` or ``input_edge_embedding`` arguments which can be set to `None` for no embedding. +Also embedding input tokens must be of dtype int now. No autocasting from float anymore. +* diff --git a/kgcnn/literature/AttentiveFP/_make.py b/kgcnn/literature/AttentiveFP/_make.py index 6b6ba2bf..bcd0ab33 100644 --- a/kgcnn/literature/AttentiveFP/_make.py +++ b/kgcnn/literature/AttentiveFP/_make.py @@ -36,7 +36,7 @@ ], "input_tensor_type": "padded", "cast_disjoint_kwargs": {}, - "input_embedding": None, + "input_embedding": None, # deprecated "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "input_edge_embedding": {"input_dim": 5, "output_dim": 64}, "attention_args": {"units": 32}, @@ -44,7 +44,8 @@ "depthato": 2, "dropout": 0.1, "verbose": 10, - "output_embedding": "graph", "output_to_tensor": True, + "output_embedding": "graph", + "output_to_tensor": True, # deprecated "output_tensor_type": "padded", "output_mlp": {"use_bias": [True, True, False], "units": [25, 10, 1], "activation": ["relu", "relu", "sigmoid"]} @@ -57,36 +58,37 @@ def make_model(inputs: list = None, input_tensor_type: str = None, input_node_embedding: dict = None, input_edge_embedding: dict = None, - input_embedding: dict = None, + input_embedding: dict = None, # noqa depthmol: int = None, depthato: int = None, dropout: float = None, attention_args: dict = None, name: str = None, - verbose: int = None, + verbose: int = None, # noqa output_embedding: str = None, - output_to_tensor: bool = None, + output_to_tensor: bool = None, # noqa output_tensor_type: str = None, output_scaling: dict = None, output_mlp: dict = None ): - r"""Make `AttentiveFP `_ graph network via functional API. + r"""Make `AttentiveFP `__ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.AttentiveFP.model_default`. Model inputs: Model uses the list template of inputs and standard output template. The supported inputs are :obj:`[nodes, edges, edge_indices, ...]` with '...' indicating mask or id tensors following the template below: + %s Model outputs: The standard output template: - %s + %s Args: - inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. - cast_disjoint_kwargs (dict): Dictionary of arguments for :obj:`CastBatchedIndicesToDisjoint` . + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers if used. input_tensor_type (str): Input type of graph tensor. Default is "padded". input_embedding (dict): Deprecated in favour of `input_node_embedding` etc. input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers. @@ -119,8 +121,8 @@ def make_model(inputs: list = None, # Wrapping disjoint model. out = model_disjoint( [n, ed, disjoint_indices, batch_id_node, count_nodes], - use_node_embedding=len(inputs[0]['shape']) < 2, - use_edge_embedding=len(inputs[1]['shape']) < 2, + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False, input_node_embedding=input_node_embedding, input_edge_embedding=input_edge_embedding, depthmol=depthmol, @@ -154,4 +156,4 @@ def set_scale(*args, **kwargs): return model -make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) \ No newline at end of file +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) diff --git a/kgcnn/literature/DMPNN/_make.py b/kgcnn/literature/DMPNN/_make.py index ccf45108..f335c9e2 100644 --- a/kgcnn/literature/DMPNN/_make.py +++ b/kgcnn/literature/DMPNN/_make.py @@ -27,8 +27,8 @@ model_default = { "name": "DMPNN", "inputs": [ - {"shape": (None,), "name": "node_attributes", "dtype": "float32"}, - {"shape": (None,), "name": "edge_attributes", "dtype": "float32"}, + {"shape": (None,), "name": "node_number", "dtype": "int64"}, + {"shape": (None,), "name": "edge_number", "dtype": "int64"}, {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, {"shape": (None, 1), "name": "edge_indices_reverse", "dtype": "int64"}, {"shape": (), "name": "total_nodes", "dtype": "int64"}, @@ -48,7 +48,8 @@ "edge_activation": {"activation": "relu"}, "node_dense": {"units": 128, "use_bias": True, "activation": "relu"}, "verbose": 10, "depth": 5, "dropout": {"rate": 0.1}, - "output_embedding": "graph", "output_to_tensor": True, + "output_embedding": "graph", + "output_to_tensor": None, # deprecated "output_tensor_type": "padded", "output_mlp": {"use_bias": [True, True, False], "units": [64, 32, 1], "activation": ["relu", "relu", "linear"]}, @@ -61,7 +62,7 @@ def make_model(name: str = None, inputs: list = None, input_tensor_type: str = None, cast_disjoint_kwargs: dict = None, - input_embedding: dict = None, + input_embedding: dict = None, # noqa input_node_embedding: dict = None, input_edge_embedding: dict = None, input_graph_embedding: dict = None, @@ -72,10 +73,10 @@ def make_model(name: str = None, node_dense: dict = None, dropout: dict = None, depth: int = None, - verbose: int = None, + verbose: int = None, # noqa use_graph_state: bool = False, output_embedding: str = None, - output_to_tensor: bool = None, + output_to_tensor: bool = None, # noqa output_tensor_type: str = None, output_mlp: dict = None, output_scaling: dict = None @@ -83,30 +84,25 @@ def make_model(name: str = None, r"""Make `DMPNN `__ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.DMPNN.model_default`. - Inputs: - list: `[node_attributes, edge_attributes, edge_indices, edge_pairs, total_nodes, total_edges]` or - `[node_attributes, edge_attributes, edge_indices, edge_pairs, total_nodes, total_edges, state_attributes]` - if `use_graph_state=True` . - - - node_attributes (Tensor): Node attributes of shape `(batch, None, F)` or `(batch, None)` - using an embedding layer. - - edge_attributes (Tensor): Edge attributes of shape `(batch, None, F)` or `(batch, None)` - using an embedding layer. - - edge_indices (Tensor): Index list for edges of shape `(batch, None, 2)`. - - edge_pairs (Tensor): Pair mappings for reverse edge for each edge `(batch, None, 1)`. - - state_attributes (Tensor): Environment or graph state attributes of shape `(batch, F)` or `(batch,)` - using an embedding layer. - - total_nodes(Tensor): Number of Nodes in graph of shape `(batch, )` . - - total_edges(Tensor): Number of Edges in graph of shape `(batch, )` . - - Outputs: - Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`. + Model inputs: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, edges, edge_indices, reverse_indices, (graph_state), ...]` + with '...' indicating mask or id tensors following the template below. + Here, reverse indices are in place of angle indices and refer to edges. The graph state is optional and controlled + by `use_graph_state` parameter. + + %s + + Model outputs: + The standard output template: + + %s Args: name (str): Name of the model. Should be "DMPNN". - inputs (list): List of dictionaries unpacked in :obj:`keras.layers.Input`. Order must match model definition. + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. input_tensor_type (str): Input type of graph tensor. Default is "padded". - cast_disjoint_kwargs (dict): Dictionary of arguments for :obj:`CastBatchedIndicesToDisjoint` . + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers if used. input_embedding (dict): Deprecated in favour of input_node_embedding etc. input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers. input_edge_embedding (dict): Dictionary of arguments for edge unpacked in :obj:`Embedding` layers. @@ -122,7 +118,7 @@ def make_model(name: str = None, verbose (int): Level for print information. use_graph_state (bool): Whether to use graph state information. Default is False. output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". - output_to_tensor (bool): Whether to cast model output to :obj:`Tensor`. + output_to_tensor (bool): WDeprecated in favour of `output_tensor_type` . output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. Defines number of model outputs and activation. @@ -150,10 +146,12 @@ def make_model(name: str = None, # Wrapping disjoint model. out = model_disjoint( [n, ed, edi, batch_id_node, e_pairs, count_nodes, gs], - use_node_embedding=len(inputs[0]['shape']) < 2, use_edge_embedding=len(inputs[1]['shape']) < 2, - use_graph_embedding=len(inputs[7]["shape"]) < 1 if use_graph_state else False, + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False, + use_graph_embedding=("int" in inputs[4]['dtype']) if input_graph_embedding is not None else False, input_node_embedding=input_node_embedding, - input_edge_embedding=input_edge_embedding, input_graph_embedding=input_graph_embedding, + input_edge_embedding=input_edge_embedding, + input_graph_embedding=input_graph_embedding, pooling_args=pooling_args, edge_initialize=edge_initialize, edge_activation=edge_activation, node_dense=node_dense, dropout=dropout, depth=depth, use_graph_state=use_graph_state, output_embedding=output_embedding, output_mlp=output_mlp, edge_dense=edge_dense @@ -180,3 +178,6 @@ def set_scale(*args, **kwargs): setattr(model, "set_scale", set_scale) return model + + +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) diff --git a/kgcnn/literature/GAT/_make.py b/kgcnn/literature/GAT/_make.py index 6523e1eb..021b299e 100644 --- a/kgcnn/literature/GAT/_make.py +++ b/kgcnn/literature/GAT/_make.py @@ -24,8 +24,8 @@ model_default = { "name": "GAT", "inputs": [ - {"shape": (None,), "name": "node_attributes", "dtype": "float32"}, - {"shape": (None,), "name": "edge_attributes", "dtype": "float32"}, + {"shape": (None,), "name": "node_number", "dtype": "int64"}, + {"shape": (None,), "name": "edge_number", "dtype": "int64"}, {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, {"shape": (), "name": "total_nodes", "dtype": "int64"}, {"shape": (), "name": "total_edges", "dtype": "int64"} @@ -34,15 +34,16 @@ "cast_disjoint_kwargs": {}, "input_embedding": None, # deprecated "input_node_embedding": {"input_dim": 95, "output_dim": 64}, - "input_edge_embedding": {"input_dim": 5, "output_dim": 64}, + "input_edge_embedding": {"input_dim": 10, "output_dim": 64}, "attention_args": {"units": 32, "use_final_activation": False, "use_edge_features": True, "has_self_loops": True, "activation": "kgcnn>leaky_relu", "use_bias": True}, "pooling_nodes_args": {"pooling_method": "scatter_mean"}, - "depth": 3, "attention_heads_num": 5, + "depth": 3, + "attention_heads_num": 5, "attention_heads_concat": False, "verbose": 10, "output_embedding": "graph", - "output_to_tensor": True, + "output_to_tensor": None, # deprecated "output_tensor_type": "padded", "output_mlp": {"use_bias": [True, True, False], "units": [25, 10, 1], "activation": ["relu", "relu", "sigmoid"]}, @@ -54,7 +55,7 @@ def make_model(inputs: list = None, input_tensor_type: str = None, cast_disjoint_kwargs: dict = None, - input_embedding: dict = None, + input_embedding: dict = None, # noqa input_node_embedding: dict = None, input_edge_embedding: dict = None, attention_args: dict = None, @@ -63,14 +64,14 @@ def make_model(inputs: list = None, attention_heads_num: int = None, attention_heads_concat: bool = None, name: str = None, - verbose: int = None, + verbose: int = None, # noqa output_embedding: str = None, - output_to_tensor: bool = None, + output_to_tensor: bool = None, # noqa output_mlp: dict = None, output_scaling: dict = None, output_tensor_type: str = None, ): - r"""Make `GAT `_ graph network via functional API. + r"""Make `GAT `__ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.GAT.model_default`. Model inputs: @@ -85,10 +86,9 @@ def make_model(inputs: list = None, %s - Args: - inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. - cast_disjoint_kwargs (dict): Dictionary of arguments for :obj:`CastBatchedIndicesToDisjoint` . + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers if used. input_tensor_type (str): Input type of graph tensor. Default is "padded". input_embedding (dict): Deprecated in favour of input_node_embedding etc. input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers. @@ -108,7 +108,7 @@ def make_model(inputs: list = None, output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Returns: - :obj:`ks.models.Model` + :obj:`keras.models.Model` """ # Make input model_inputs = [Input(**x) for x in inputs] @@ -121,8 +121,10 @@ def make_model(inputs: list = None, # Wrapping disjoint model. out = model_disjoint( [n, ed, disjoint_indices, batch_id_node, count_nodes], - use_node_embedding=len(inputs[0]['shape']) < 2, use_edge_embedding=len(inputs[1]['shape']) < 2, - input_node_embedding=input_node_embedding, input_edge_embedding=input_edge_embedding, + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False, + input_node_embedding=input_node_embedding, + input_edge_embedding=input_edge_embedding, attention_args=attention_args, pooling_nodes_args=pooling_nodes_args, depth=depth, attention_heads_num=attention_heads_num, attention_heads_concat=attention_heads_concat, output_embedding=output_embedding, output_mlp=output_mlp @@ -151,4 +153,4 @@ def set_scale(*args, **kwargs): return model -make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) \ No newline at end of file +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) diff --git a/kgcnn/literature/GATv2/_make.py b/kgcnn/literature/GATv2/_make.py index 9fa81149..2401aecc 100644 --- a/kgcnn/literature/GATv2/_make.py +++ b/kgcnn/literature/GATv2/_make.py @@ -44,7 +44,8 @@ "pooling_nodes_args": {"pooling_method": "scatter_mean"}, 'depth': 3, 'attention_heads_num': 5, 'attention_heads_concat': False, 'verbose': 10, - 'output_embedding': 'graph', "output_to_tensor": True, + 'output_embedding': 'graph', + "output_to_tensor": None, # deprecated "output_tensor_type": "padded", 'output_mlp': {"use_bias": [True, True, False], "units": [25, 10, 1], "activation": ['relu', 'relu', 'sigmoid']}, @@ -56,7 +57,7 @@ def make_model(inputs: list = None, input_tensor_type: str = None, cast_disjoint_kwargs: dict = None, - input_embedding: dict = None, + input_embedding: dict = None, # noqa input_node_embedding: dict = None, input_edge_embedding: dict = None, attention_args: dict = None, @@ -65,9 +66,9 @@ def make_model(inputs: list = None, attention_heads_num: int = None, attention_heads_concat: bool = None, name: str = None, - verbose: int = None, + verbose: int = None, # noqa output_embedding: str = None, - output_to_tensor: bool = None, + output_to_tensor: bool = None, # noqa output_mlp: dict = None, output_scaling: dict = None, output_tensor_type: str = None, @@ -75,23 +76,21 @@ def make_model(inputs: list = None, r"""Make `GATv2 `__ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.GATv2.model_default`. - Inputs: - list: `[node_attributes, edge_attributes, edge_indices, total_nodes, total_edges]` + Model inputs: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, edges, edge_indices, ...]` + with '...' indicating mask or ID tensors following the template below: - - node_attributes (Tensor): Node attributes of shape `(batch, None, F)` or `(batch, None)` - using an embedding layer. - - edge_attributes (Tensor): Edge attributes of shape `(batch, None, F)` or `(batch, None)` - using an embedding layer. - - edge_indices (Tensor): Index list for edges of shape `(batch, None, 2)`. - - total_nodes(Tensor, optional): Number of Nodes in graph if not same sized graphs of shape `(batch, )` . - - total_edges(Tensor, optional): Number of Edges in graph if not same sized graphs of shape `(batch, )` . + %s - Outputs: - Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`. + Model outputs: + The standard output template: + + %s Args: - inputs (list): List of dictionaries unpacked in :obj:`ks.layers.Input`. Order must match model definition. - cast_disjoint_kwargs (dict): Dictionary of arguments for :obj:`CastBatchedIndicesToDisjoint` . + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers if used. input_tensor_type (str): Input type of graph tensor. Default is "padded". input_embedding (dict): Deprecated in favour of input_node_embedding etc. input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers. @@ -104,27 +103,28 @@ def make_model(inputs: list = None, name (str): Name of the model. verbose (int): Level of print output. output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". - output_to_tensor (bool): Whether to cast model output to :obj:`Tensor`. + output_to_tensor (bool): Deprecated in favour of `output_tensor_type` . output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. Defines number of model outputs and activation. output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". Returns: - :obj:`ks.models.Model` + :obj:`keras.models.Model` """ # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_model_inputs = template_cast_list_input( + dj_model_inputs = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) - n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_model_inputs + n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_model_inputs # Wrapping disjoint model. out = model_disjoint( [n, ed, disjoint_indices, batch_id_node, count_nodes], - use_node_embedding=len(inputs[0]['shape']) < 2, use_edge_embedding=len(inputs[1]['shape']) < 2, + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False, input_node_embedding=input_node_embedding, input_edge_embedding=input_edge_embedding, attention_args=attention_args, pooling_nodes_args=pooling_nodes_args, depth=depth, attention_heads_num=attention_heads_num, attention_heads_concat=attention_heads_concat, @@ -152,3 +152,6 @@ def set_scale(*args, **kwargs): setattr(model, "set_scale", set_scale) return model + + +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) diff --git a/kgcnn/literature/GCN/_make.py b/kgcnn/literature/GCN/_make.py index 3f39d7fb..df997b71 100644 --- a/kgcnn/literature/GCN/_make.py +++ b/kgcnn/literature/GCN/_make.py @@ -1,8 +1,4 @@ import keras_core as ks -from kgcnn.layers.casting import (CastBatchedIndicesToDisjoint, CastBatchedAttributesToDisjoint, - CastDisjointToBatchedGraphState, CastDisjointToBatchedAttributes, - CastBatchedGraphStateToDisjoint, CastRaggedAttributesToDisjoint, - CastRaggedIndicesToDisjoint, CastDisjointToRaggedAttributes) from kgcnn.layers.scale import get as get_scaler from ._model import model_disjoint from kgcnn.layers.modules import Input @@ -31,7 +27,7 @@ model_default = { "name": "GCN", "inputs": [ - {"shape": (None,), "name": "node_attributes", "dtype": "float32"}, + {"shape": (None,), "name": "node_number", "dtype": "int64"}, {"shape": (None, 1), "name": "edge_weights", "dtype": "float32"}, {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, {"shape": (), "name": "total_nodes", "dtype": "int64"}, @@ -47,7 +43,7 @@ "verbose": 10, "node_pooling_args": {"pooling_method": "scatter_sum"}, "output_embedding": "graph", - "output_to_tensor": True, + "output_to_tensor": None, # deprecated "output_tensor_type": "padded", "output_mlp": {"use_bias": [True, True, False], "units": [25, 10, 1], "activation": ["relu", "relu", "sigmoid"]}, @@ -59,40 +55,39 @@ def make_model(inputs: list = None, input_tensor_type: str = None, cast_disjoint_kwargs: dict = None, - input_embedding: dict = None, + input_embedding: dict = None, # noqa input_node_embedding: dict = None, input_edge_embedding: dict = None, depth: int = None, gcn_args: dict = None, name: str = None, - verbose: int = None, + verbose: int = None, # noqa node_pooling_args: dict = None, output_embedding: str = None, - output_to_tensor: bool = None, + output_to_tensor: bool = None, # noqa output_tensor_type: str = None, output_mlp: dict = None, output_scaling: dict = None): - r"""Make `GCN `_ graph network via functional API. + r"""Make `GCN `__ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.GCN.model_default`. - Inputs: - list: `[node_attributes, edge_weights, edge_indices, total_nodes, total_edges]` + Model inputs: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, edges, edge_indices, ...]` + with '...' indicating mask or ID tensors following the template below. + Edges are actually edge single weight values which are entries of the pre-scaled adjacency matrix. - - node_attributes (Tensor): Node attributes of shape `(batch, N, F)` or `(batch, N)` - using an embedding layer. - - edge_weights (Tensor): Edge weights of shape `(batch, M, 1)` , that are entries of a scaled - adjacency matrix. - - edge_indices (Tensor): Index list for edges of shape `(batch, M, 2)` . - - total_nodes(Tensor, optional): Number of Nodes in graph if not same sized graphs of shape `(batch, )` . - - total_edges(Tensor, optional): Number of Edges in graph if not same sized graphs of shape `(batch, )` . + %s - Outputs: - Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`. + Model outputs: + The standard output template: + + %s Args: - inputs (list): List of dictionaries unpacked in :obj:`ks.layers.Input`. Order must match model definition. + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. input_tensor_type (str): Input type of graph tensor. Default is "padded". - cast_disjoint_kwargs (dict): Dictionary of arguments for :obj:`CastBatchedIndicesToDisjoint` . + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers if used. input_embedding (dict): Deprecated in favour of input_node_embedding etc. input_node_embedding (dict): Dictionary of embedding arguments unpacked in :obj:`Embedding` layers. input_edge_embedding (dict): Dictionary of embedding arguments unpacked in :obj:`Embedding` layers. @@ -102,14 +97,14 @@ def make_model(inputs: list = None, verbose (int): Level of print output. node_pooling_args (dict): Dictionary of layer arguments unpacked in :obj:`PoolingNodes` layer. output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". - output_to_tensor (bool): Whether to cast model output to :obj:`Tensor`. + output_to_tensor (bool): Deprecated in favour of `output_tensor_type` . output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. Defines number of model outputs and activation. output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". Returns: - :obj:`ks.models.Model` + :obj:`keras.models.Model` """ if inputs[1]['shape'][-1] != 1: raise ValueError("No edge features available for GCN, only edge weights of pre-scaled adjacency matrix, \ @@ -118,14 +113,15 @@ def make_model(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_model_inputs = template_cast_list_input( + dj_inputs = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) - n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_model_inputs + n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs out = model_disjoint( [n, ed, disjoint_indices, batch_id_node, count_nodes], - use_node_embedding=len(inputs[0]['shape']) < 2, use_edge_embedding=len(inputs[1]['shape']) < 2, + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False, input_node_embedding=input_node_embedding, input_edge_embedding=input_edge_embedding, depth=depth, gcn_args=gcn_args, node_pooling_args=node_pooling_args, output_embedding=output_embedding, output_mlp=output_mlp @@ -151,3 +147,6 @@ def set_scale(*args, **kwargs): setattr(model, "set_scale", set_scale) return model + + +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) diff --git a/kgcnn/literature/GIN/_make.py b/kgcnn/literature/GIN/_make.py index fabcf1a1..37453cd2 100644 --- a/kgcnn/literature/GIN/_make.py +++ b/kgcnn/literature/GIN/_make.py @@ -1,8 +1,4 @@ -import keras_core as ks -from kgcnn.layers.casting import (CastBatchedIndicesToDisjoint, CastBatchedAttributesToDisjoint, - CastDisjointToBatchedGraphState, CastDisjointToBatchedAttributes, - CastBatchedGraphStateToDisjoint, CastRaggedAttributesToDisjoint, - CastRaggedIndicesToDisjoint, CastDisjointToRaggedAttributes) +# import keras_core as ks from ._model import model_disjoint, model_disjoint_edge from kgcnn.models.utils import update_model_kwargs from kgcnn.layers.scale import get as get_scaler @@ -28,7 +24,7 @@ model_default = { "name": "GIN", "inputs": [ - {"shape": (None,), "name": "node_attributes", "dtype": "float32"}, + {"shape": (None,), "name": "node_number", "dtype": "int64"}, {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, {"shape": (), "name": "total_nodes", "dtype": "int64"}, {"shape": (), "name": "total_edges", "dtype": "int64"} @@ -44,7 +40,8 @@ "depth": 3, "dropout": 0.0, "verbose": 10, "last_mlp": {"use_bias": [True, True, True], "units": [64, 64, 64], "activation": ["relu", "relu", "linear"]}, - "output_embedding": 'graph', "output_to_tensor": True, + "output_embedding": 'graph', + "output_to_tensor": None, # deprecated "output_tensor_type": "padded", "output_mlp": {"use_bias": True, "units": 1, "activation": "softmax"} @@ -55,7 +52,7 @@ def make_model(inputs: list = None, input_tensor_type: str = None, cast_disjoint_kwargs: dict = None, - input_embedding: dict = None, + input_embedding: dict = None, # noqa input_node_embedding: dict = None, depth: int = None, gin_args: dict = None, @@ -63,32 +60,32 @@ def make_model(inputs: list = None, last_mlp: dict = None, dropout: float = None, name: str = None, - verbose: int = None, + verbose: int = None, # noqa output_embedding: str = None, - output_to_tensor: bool = None, + output_to_tensor: bool = None, # noqa output_mlp: dict = None, output_scaling: dict = None, output_tensor_type: str = None, ): - r"""Make `GIN `_ graph network via functional API. + r"""Make `GIN `__ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.GIN.model_default`. - Inputs: - list: `[node_attributes, edge_attributes, edge_indices, total_nodes, total_edges]` + Model inputs: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, edge_indices, ...]` + with '...' indicating mask or ID tensors following the template below: + + %s - - node_attributes (Tensor): Node attributes of shape `(batch, None, F)` or `(batch, None)` - using an embedding layer. - - edge_indices (Tensor): Index list for edges of shape `(batch, None, 2)`. - - total_nodes(Tensor, optional): Number of Nodes in graph if not same sized graphs of shape `(batch, )` . - - total_edges(Tensor, optional): Number of Edges in graph if not same sized graphs of shape `(batch, )` . + Model outputs: + The standard output template: - Outputs: - tf.Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`. + %s Args: inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. input_tensor_type (str): Input type of graph tensor. Default is "padded". - cast_disjoint_kwargs (dict): Dictionary of arguments for :obj:`CastBatchedIndicesToDisjoint` . + cast_disjoint_kwargs (dict): Dictionary of arguments for castin layers. input_embedding (dict): Deprecated in favour of input_node_embedding etc. input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers. depth (int): Number of graph embedding units or depth of the network. @@ -99,14 +96,14 @@ def make_model(inputs: list = None, name (str): Name of the model. verbose (int): Level of print output. output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". - output_to_tensor (bool): Whether to cast model output to :obj:`tf.Tensor`. + output_to_tensor (bool): Deprecated in favour of `output_tensor_type` . output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. Defines number of model outputs and activation. output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". Returns: - :obj:`tf.keras.models.Model` + :obj:`keras.models.Model` """ # Make input model_inputs = [Input(**x) for x in inputs] @@ -121,8 +118,10 @@ def make_model(inputs: list = None, # Wrapping disjoint model. out = model_disjoint( [n, disjoint_indices, batch_id_node, count_nodes], - use_node_embedding=len(inputs[0]['shape']) < 2, input_node_embedding=input_node_embedding, - depth=depth, gin_args=gin_args, gin_mlp=gin_mlp, last_mlp=last_mlp, dropout=dropout, + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + input_node_embedding=input_node_embedding, + depth=depth, gin_args=gin_args, gin_mlp=gin_mlp, + last_mlp=last_mlp, dropout=dropout, output_embedding=output_embedding, output_mlp=output_mlp ) @@ -148,11 +147,14 @@ def set_scale(*args, **kwargs): return model +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) + + model_default_edge = { "name": "GINE", "inputs": [ - {"shape": (None,), "name": "node_attributes", "dtype": "float32"}, - {"shape": (None,), "name": "edge_attributes", "dtype": "float32"}, + {"shape": (None,), "name": "node_number", "dtype": "int64"}, + {"shape": (None,), "name": "edge_number", "dtype": "int64"}, {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, {"shape": (), "name": "total_nodes", "dtype": "int64"}, {"shape": (), "name": "total_edges", "dtype": "int64"} @@ -161,7 +163,7 @@ def set_scale(*args, **kwargs): "cast_disjoint_kwargs": {}, "input_embedding": None, # deprecated "input_node_embedding": {"input_dim": 95, "output_dim": 64}, - "input_edge_embedding": {"input_dim": 5, "output_dim": 64}, + "input_edge_embedding": {"input_dim": 10, "output_dim": 64}, "gin_mlp": {"units": [64, 64], "use_bias": True, "activation": ["relu", "linear"], "use_normalization": True, "normalization_technique": "graph_batch"}, "gin_args": {"epsilon_learnable": False}, @@ -179,7 +181,7 @@ def set_scale(*args, **kwargs): def make_model_edge(inputs: list = None, input_tensor_type: str = None, cast_disjoint_kwargs: dict = None, - input_embedding: dict = None, + input_embedding: dict = None, # noqa input_node_embedding: dict = None, input_edge_embedding: dict = None, depth: int = None, @@ -188,9 +190,9 @@ def make_model_edge(inputs: list = None, last_mlp: dict = None, dropout: float = None, name: str = None, - verbose: int = None, + verbose: int = None, # noqa output_embedding: str = None, - output_to_tensor: bool = None, + output_to_tensor: bool = None, # noqa output_mlp: dict = None, output_scaling: dict = None, output_tensor_type: str = None, @@ -198,24 +200,22 @@ def make_model_edge(inputs: list = None, r"""Make `GINE `__ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.GIN.model_default_edge` . - Inputs: - list: `[node_attributes, edge_attributes, edge_indices, total_nodes, total_edges]` + Model inputs: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, edges, edge_indices, ...]` + with '...' indicating mask or ID tensors following the template below: + + %s - - node_attributes (Tensor): Node attributes of shape `(batch, None, F)` or `(batch, None)` - using an embedding layer. - - edge_attributes (Tensor): Edge attributes of shape `(batch, None, F)` or `(batch, None)` - using an embedding layer. - - edge_indices (Tensor): Index list for edges of shape `(batch, None, 2)`. - - total_nodes(Tensor, optional): Number of Nodes in graph if not same sized graphs of shape `(batch, )` . - - total_edges(Tensor, optional): Number of Edges in graph if not same sized graphs of shape `(batch, )` . + Model outputs: + The standard output template: - Outputs: - Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`. + %s Args: - inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. input_tensor_type (str): Input type of graph tensor. Default is "padded". - cast_disjoint_kwargs (dict): Dictionary of arguments for :obj:`CastBatchedIndicesToDisjoint` . + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers. input_embedding (dict): Deprecated in favour of input_node_embedding etc. input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers. input_edge_embedding (dict): Dictionary of arguments for edge unpacked in :obj:`Embedding` layers. @@ -227,14 +227,14 @@ def make_model_edge(inputs: list = None, name (str): Name of the model. verbose (int): Level of print output. output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". - output_to_tensor (bool): Whether to cast model output to :obj:`tf.Tensor`. + output_to_tensor (bool): Deprecated in favour of `output_tensor_type` . output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. Defines number of model outputs and activation. output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". Returns: - :obj:`ks.models.Model` + :obj:`keras.models.Model` """ # Make input model_inputs = [Input(**x) for x in inputs] @@ -247,9 +247,10 @@ def make_model_edge(inputs: list = None, # Wrapping disjoint model. out = model_disjoint_edge( [n, ed, disjoint_indices, batch_id_node, count_nodes], - use_node_embedding="float" not in inputs[0]['dtype'], - use_edge_embedding="float" not in inputs[1]['dtype'], - input_node_embedding=input_node_embedding, input_edge_embedding=input_edge_embedding, + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False, + input_node_embedding=input_node_embedding, + input_edge_embedding=input_edge_embedding, depth=depth, gin_args=gin_args, gin_mlp=gin_mlp, last_mlp=last_mlp, dropout=dropout, output_embedding=output_embedding, output_mlp=output_mlp ) @@ -276,3 +277,4 @@ def set_scale(*args, **kwargs): return model +make_model_edge.__doc__ = make_model_edge.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) diff --git a/kgcnn/literature/GraphSAGE/_make.py b/kgcnn/literature/GraphSAGE/_make.py index 722f1bea..18e822db 100644 --- a/kgcnn/literature/GraphSAGE/_make.py +++ b/kgcnn/literature/GraphSAGE/_make.py @@ -1,8 +1,4 @@ import keras_core as ks -from kgcnn.layers.casting import (CastBatchedIndicesToDisjoint, CastBatchedAttributesToDisjoint, - CastDisjointToBatchedGraphState, CastDisjointToBatchedAttributes, - CastBatchedGraphStateToDisjoint, CastRaggedAttributesToDisjoint, - CastRaggedIndicesToDisjoint, CastDisjointToRaggedAttributes) from ._model import model_disjoint from kgcnn.layers.modules import Input from kgcnn.models.utils import update_model_kwargs @@ -28,8 +24,8 @@ model_default = { 'name': "GraphSAGE", 'inputs': [ - {"shape": (None,), "name": "node_attributes", "dtype": "float32"}, - {"shape": (None,), "name": "edge_attributes", "dtype": "float32"}, + {"shape": (None,), "name": "node_number", "dtype": "int64"}, + {"shape": (None,), "name": "edge_number", "dtype": "int64"}, {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, {"shape": (), "name": "total_nodes", "dtype": "int64"}, {"shape": (), "name": "total_edges", "dtype": "int64"} @@ -45,7 +41,8 @@ 'concat_args': {"axis": -1}, 'use_edge_features': True, 'pooling_nodes_args': {'pooling_method': "scatter_mean"}, 'depth': 3, 'verbose': 10, - 'output_embedding': 'graph', "output_to_tensor": True, + 'output_embedding': 'graph', + "output_to_tensor": None, # deprecated "output_tensor_type": "padded", 'output_mlp': {"use_bias": [True, True, False], "units": [25, 10, 1], "activation": ['relu', 'relu', 'sigmoid']}, @@ -57,7 +54,7 @@ def make_model(inputs: list = None, input_tensor_type: str = None, cast_disjoint_kwargs: dict = None, - input_embedding: dict = None, + input_embedding: dict = None, # noqa input_node_embedding: dict = None, input_edge_embedding: dict = None, node_mlp_args: dict = None, @@ -69,33 +66,31 @@ def make_model(inputs: list = None, use_edge_features: bool = None, depth: int = None, name: str = None, - verbose: int = None, + verbose: int = None, # noqa output_embedding: str = None, output_tensor_type: str = None, - output_to_tensor: bool = None, + output_to_tensor: bool = None, # noqa output_mlp: dict = None, output_scaling: dict = None ): r"""Make `GraphSAGE `__ graph network via functional API. - Default parameters can be found in :obj:`kgcnn.literature.GraphSAGE.model_default` . - Inputs: - list: `[node_attributes, edge_attributes, edge_indices, total_nodes, total_edges]` + Model inputs: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, edges, edge_indices, ...]` + with '...' indicating mask or ID tensors following the template below. + Edges are actually edge single weight values which are entries of the pre-scaled adjacency matrix. + + %s - - node_attributes (Tensor): Node attributes of shape `(batch, None, F)` or `(batch, None)` - using an embedding layer. - - edge_attributes (Tensor): Edge attributes of shape `(batch, None, F)` or `(batch, None)` - using an embedding layer. - - edge_indices (Tensor): Index list for edges of shape `(batch, None, 2)`. - - total_nodes(Tensor, optional): Number of Nodes in graph if not same sized graphs of shape `(batch, )` . - - total_edges(Tensor, optional): Number of Edges in graph if not same sized graphs of shape `(batch, )` . + Model outputs: + The standard output template: - Outputs: - tf.Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`. + %s Args: - inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. input_tensor_type (str): Input type of graph tensor. Default is "padded". cast_disjoint_kwargs (dict): Dictionary of arguments for casting layer. input_embedding (dict): Deprecated in favour of input_node_embedding etc. @@ -112,26 +107,27 @@ def make_model(inputs: list = None, name (str): Name of the model. verbose (int): Level of print output. output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". - output_to_tensor (bool): Whether to cast model output to :obj:`Tensor`. + output_to_tensor (bool): Deprecated in favour of `output_tensor_type` . output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. Defines number of model outputs and activation. output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". Returns: - :obj:`ks.models.Model` + :obj:`keras.models.Model` """ # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_model_inputs = template_cast_list_input( + dj_inputs = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) - n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_model_inputs + n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs out = model_disjoint( [n, ed, disjoint_indices, batch_id_node, batch_id_edge, count_nodes, count_edges], - use_node_embedding=len(inputs[0]['shape']) < 2, use_edge_embedding=len(inputs[1]['shape']) < 2, + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False, input_node_embedding=input_node_embedding, input_edge_embedding=input_edge_embedding, node_mlp_args=node_mlp_args, edge_mlp_args=edge_mlp_args, pooling_args=pooling_args, pooling_nodes_args=pooling_nodes_args, gather_args=gather_args, concat_args=concat_args, @@ -160,3 +156,6 @@ def set_scale(*args, **kwargs): setattr(model, "set_scale", set_scale) return model + + +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) diff --git a/kgcnn/literature/PAiNN/_make.py b/kgcnn/literature/PAiNN/_make.py index f72a7a9e..51cbac7a 100644 --- a/kgcnn/literature/PAiNN/_make.py +++ b/kgcnn/literature/PAiNN/_make.py @@ -23,7 +23,7 @@ model_default = { "name": "PAiNN", "inputs": [ - {"shape": (None,), "name": "node_attributes", "dtype": "float32"}, + {"shape": (None,), "name": "node_number", "dtype": "int64"}, {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32"}, {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, {"shape": (), "name": "total_nodes", "dtype": "int64"}, @@ -43,7 +43,8 @@ "depth": 3, "verbose": 10, "output_embedding": "graph", - "output_to_tensor": True, "output_tensor_type": "padded", + "output_to_tensor": None, # deprecated + "output_tensor_type": "padded", "output_scaling": None, "output_mlp": {"use_bias": [True, True], "units": [128, 1], "activation": ["swish", "linear"]} } @@ -53,7 +54,7 @@ def make_model(inputs: list = None, input_tensor_type: str = None, cast_disjoint_kwargs: dict = None, - input_embedding: dict = None, + input_embedding: dict = None, # noqa input_node_embedding: dict = None, has_equivariant_input: bool = None, equiv_initialize_kwargs: dict = None, @@ -65,39 +66,40 @@ def make_model(inputs: list = None, equiv_normalization: bool = None, node_normalization: bool = None, name: str = None, - verbose: int = None, + verbose: int = None, # noqa output_embedding: str = None, - output_to_tensor: bool = None, + output_to_tensor: bool = None, # noqa output_mlp: dict = None, output_scaling: dict = None, output_tensor_type: str = None ): - r"""Make `PAiNN `_ graph network via functional API. + r"""Make `PAiNN `__ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.PAiNN.model_default`. - Inputs: - list: `[node_attributes, node_coordinates, bond_indices, total_nodes, total_edges]` - or `[node_attributes, node_coordinates, bond_indices, total_nodes, total_edges, equiv_initial]` - if a custom equivariant initialization is chosen other than zero. + Model inputs: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, coordinates, edge_indices, ...]` + with '...' indicating mask or ID tensors following the template below. + If equivariant input is used via `has_equivariant_input` then input is extended to + :obj:`[equiv, nodes, coordinates, edge_indices, ...]` - - node_attributes (Tensor): Node attributes of shape `(batch, None, F)` or `(batch, None)` - using an embedding layer. - - node_coordinates (Tensor): Atomic coordinates of shape `(batch, None, 3)`. - - bond_indices (Tensor): Index list for edges or bonds of shape `(batch, None, 2)`. - - equiv_initial (Tensor): Equivariant initialization `(batch, None, 3, F)`. Optional. - - total_nodes(Tensor): Number of Nodes in graph if not same sized graphs of shape `(batch, )` . - - total_edges(Tensor): Number of Edges in graph if not same sized graphs of shape `(batch, )` . + %s - Outputs: - Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`. + Model outputs: + The standard output template: + + %s Args: - inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. - cast_disjoint_kwargs (dict): Dictionary of arguments for :obj:`CastBatchedIndicesToDisjoint` . + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. + input_tensor_type (str): Input type of graph tensor. Default is "padded". + input_embedding (dict): Deprecated in favour of input_node_embedding etc. + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers. input_node_embedding (dict): Dictionary of embedding arguments for nodes unpacked in :obj:`Embedding` layers. equiv_initialize_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`EquivariantInitialize` layer. bessel_basis (dict): Dictionary of layer arguments unpacked in final :obj:`BesselBasisLayer` layer. depth (int): Number of graph embedding units or depth of the network. + has_equivariant_input (bool): Whether the first equivariant node embedding is passed to the model. pooling_args (dict): Dictionary of layer arguments unpacked in :obj:`PoolingNodes` layer. conv_args (dict): Dictionary of layer arguments unpacked in :obj:`PAiNNconv` layer. update_args (dict): Dictionary of layer arguments unpacked in :obj:`PAiNNUpdate` layer. @@ -106,10 +108,11 @@ def make_model(inputs: list = None, verbose (int): Level of verbosity. name (str): Name of the model. output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". - output_to_tensor (bool): Whether to cast model output to :obj:`tf.Tensor`. + output_to_tensor (bool): Deprecated in favour of `output_tensor_type` . output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. Defines number of model outputs and activation. output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. + output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". Returns: :obj:`keras.models.Model` @@ -119,7 +122,7 @@ def make_model(inputs: list = None, disjoint_inputs = template_cast_list_input(model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_nodes=2+int(has_equivariant_input), + has_nodes=2 + int(has_equivariant_input), has_edges=False) if not has_equivariant_input: @@ -131,8 +134,9 @@ def make_model(inputs: list = None, # Wrapping disjoint model. out = model_disjoint( [z, x, edi, batch_id_node, batch_id_edge, count_nodes, count_edges, v], - use_node_embedding=len(inputs[0]['shape']) < 2, - input_node_embedding=input_node_embedding, equiv_initialize_kwargs=equiv_initialize_kwargs, + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + input_node_embedding=input_node_embedding, + equiv_initialize_kwargs=equiv_initialize_kwargs, bessel_basis=bessel_basis, depth=depth, pooling_args=pooling_args, conv_args=conv_args, update_args=update_args, equiv_normalization=equiv_normalization, node_normalization=node_normalization, output_embedding=output_embedding, output_mlp=output_mlp @@ -165,14 +169,19 @@ def set_scale(*args, **kwargs): return model +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) + + model_crystal_default = { "name": "PAiNN", "inputs": [ - {"shape": (None,), "name": "node_attributes", "dtype": "int64", "ragged": True}, - {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32", "ragged": True}, - {"shape": (None, 2), "name": "edge_indices", "dtype": "int64", "ragged": True}, - {'shape': (None, 3), 'name': "edge_image", 'dtype': 'int64', 'ragged': True}, - {'shape': (3, 3), 'name': "graph_lattice", 'dtype': 'float32', 'ragged': False} + {"shape": (None,), "name": "node_number", "dtype": "int64"}, + {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32"}, + {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, + {'shape': (None, 3), 'name': "edge_image", 'dtype': 'int64'}, + {'shape': (3, 3), 'name': "graph_lattice", 'dtype': 'float32'}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"} ], "input_tensor_type": "padded", "input_embedding": None, # deprecated @@ -197,7 +206,7 @@ def set_scale(*args, **kwargs): @update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"]) def make_crystal_model(inputs: list = None, input_tensor_type: str = None, - input_embedding: dict = None, + input_embedding: dict = None, # noqa cast_disjoint_kwargs: dict = None, has_equivariant_input: bool = None, input_node_embedding: dict = None, @@ -210,40 +219,41 @@ def make_crystal_model(inputs: list = None, equiv_normalization: bool = None, node_normalization: bool = None, name: str = None, - verbose: int = None, + verbose: int = None, # noqa output_embedding: str = None, - output_to_tensor: bool = None, + output_to_tensor: bool = None, # noqa output_mlp: dict = None, output_scaling: dict = None, output_tensor_type: str = None ): - r"""Make `PAiNN `_ graph network via functional API. + r"""Make `PAiNN `__ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.PAiNN.model_crystal_default`. - Inputs: - list: `[node_attributes, node_coordinates, bond_indices, edge_image, lattice]` - or `[node_attributes, node_coordinates, bond_indices, edge_image, lattice, equiv_initial]` if a custom - equivariant initialization is chosen other than zero. + Model inputs: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, coordinates, edge_indices, image_translation, lattice, ...]` + with '...' indicating mask or ID tensors following the template below. + If equivariant input is used via `has_equivariant_input` then input is extended to + :obj:`[equiv, nodes, coordinates, edge_indices, image_translation, lattice, ...]` + + %s - - node_attributes (tf.RaggedTensor): Node attributes of shape `(batch, None, F)` or `(batch, None)` - using an embedding layer. - - node_coordinates (tf.RaggedTensor): Atomic coordinates of shape `(batch, None, 3)`. - - bond_indices (tf.RaggedTensor): Index list for edges or bonds of shape `(batch, None, 2)`. - - equiv_initial (tf.RaggedTensor): Equivariant initialization `(batch, None, 3, F)`. Optional. - - lattice (tf.Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)`. - - edge_image (tf.RaggedTensor): Indices of the periodic image the sending node is located. The indices - of and edge are :math:`(i, j)` with :math:`j` being the sending node. + Model outputs: + The standard output template: - Outputs: - tf.Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`. + %s Args: - inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. - input_embedding (dict): Dictionary of embedding arguments for nodes etc. unpacked in :obj:`Embedding` layers. + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. + input_tensor_type (str): Input type of graph tensor. Default is "padded". + input_embedding (dict): Deprecated in favour of input_node_embedding etc. + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers. + input_node_embedding (dict): Dictionary of embedding arguments for nodes unpacked in :obj:`Embedding` layers. bessel_basis (dict): Dictionary of layer arguments unpacked in final :obj:`BesselBasisLayer` layer. equiv_initialize_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`EquivariantInitialize` layer. depth (int): Number of graph embedding units or depth of the network. pooling_args (dict): Dictionary of layer arguments unpacked in :obj:`PoolingNodes` layer. + has_equivariant_input (bool): Whether the first equivariant node embedding is passed to the model. conv_args (dict): Dictionary of layer arguments unpacked in :obj:`PAiNNconv` layer. update_args (dict): Dictionary of layer arguments unpacked in :obj:`PAiNNUpdate` layer. equiv_normalization (bool): Whether to apply :obj:`GraphLayerNormalization` to equivariant tensor update. @@ -251,31 +261,34 @@ def make_crystal_model(inputs: list = None, verbose (int): Level of verbosity. name (str): Name of the model. output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". - output_to_tensor (bool): Whether to cast model output to :obj:`tf.Tensor`. + output_to_tensor (bool): Deprecated in favour of `output_tensor_type` . output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. Defines number of model outputs and activation. + output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. + output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". Returns: - :obj:`tf.keras.models.Model` + :obj:`keras.models.Model` """ # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_inputs = template_cast_list_input(model_inputs, input_tensor_type=input_tensor_type, - cast_disjoint_kwargs=cast_disjoint_kwargs, - has_nodes=2+int(has_equivariant_input), has_crystal_input=2, - has_edges=False) + dj_inputs = template_cast_list_input( + model_inputs, input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + has_nodes=2 + int(has_equivariant_input), has_crystal_input=2, + has_edges=False) if not has_equivariant_input: - z, x, edi, edge_image, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs + z, x, edi, img, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs v = None else: - v, z, x, edi, edge_image, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs + v, z, x, edi, img, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs # Wrapping disjoint model. out = model_disjoint_crystal( - [z, x, edi, edge_image, lattice, batch_id_node, batch_id_edge, count_nodes, count_edges, v], - use_node_embedding=len(inputs[0]['shape']) < 2, + [z, x, edi, img, lattice, batch_id_node, batch_id_edge, count_nodes, count_edges, v], + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, input_node_embedding=input_node_embedding, equiv_initialize_kwargs=equiv_initialize_kwargs, bessel_basis=bessel_basis, depth=depth, pooling_args=pooling_args, conv_args=conv_args, update_args=update_args, equiv_normalization=equiv_normalization, node_normalization=node_normalization, @@ -308,3 +321,7 @@ def set_scale(*args, **kwargs): model.__kgcnn_model_version__ = __model_version__ return model + + +make_crystal_model.__doc__ = make_crystal_model.__doc__ % ( + template_cast_list_input.__doc__, template_cast_output.__doc__) diff --git a/kgcnn/literature/Schnet/_make.py b/kgcnn/literature/Schnet/_make.py index f46a8fb5..1d55b0d1 100644 --- a/kgcnn/literature/Schnet/_make.py +++ b/kgcnn/literature/Schnet/_make.py @@ -1,8 +1,4 @@ import keras_core as ks -from kgcnn.layers.casting import (CastBatchedIndicesToDisjoint, CastBatchedAttributesToDisjoint, - CastDisjointToBatchedGraphState, CastDisjointToBatchedAttributes, - CastBatchedGraphStateToDisjoint, CastRaggedAttributesToDisjoint, - CastRaggedIndicesToDisjoint, CastDisjointToRaggedAttributes) from kgcnn.layers.scale import get as get_scaler from ._model import model_disjoint, model_disjoint_crystal from kgcnn.layers.modules import Input @@ -49,7 +45,8 @@ "verbose": 10, "last_mlp": {"use_bias": [True, True], "units": [128, 64], "activation": ["kgcnn>shifted_softplus", "kgcnn>shifted_softplus"]}, - "output_embedding": "graph", "output_to_tensor": True, + "output_embedding": "graph", + "output_to_tensor": None, # deprecated "use_output_mlp": True, "output_tensor_type": "padded", "output_scaling": None, @@ -62,7 +59,7 @@ def make_model(inputs: list = None, input_tensor_type: str = None, cast_disjoint_kwargs: dict = None, - input_embedding: dict = None, + input_embedding: dict = None, # noqa input_node_embedding: dict = None, make_distance: bool = None, expand_distance: bool = None, @@ -71,11 +68,11 @@ def make_model(inputs: list = None, node_pooling_args: dict = None, depth: int = None, name: str = None, - verbose: int = None, + verbose: int = None, # noqa last_mlp: dict = None, output_embedding: str = None, use_output_mlp: bool = None, - output_to_tensor: bool = None, + output_to_tensor: bool = None, # noqa output_mlp: dict = None, output_tensor_type: str = None, output_scaling: dict = None @@ -83,27 +80,22 @@ def make_model(inputs: list = None, r"""Make `SchNet `__ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.Schnet.model_default` . - Inputs: - list: `[node_attributes, edge_distance, edge_indices, total_nodes, total_edges]` - or `[node_attributes, node_coordinates, edge_indices, total_nodes, total_edges]` - if :obj:`make_distance=True` and :obj:`expand_distance=True` - to compute edge distances from node coordinates within the model. - - - node_attributes (tf.RaggedTensor): Node attributes of shape `(batch, N, F)` or `(batch, N)` - using an embedding layer. - - edge_distance (tf.RaggedTensor): Edge distance of shape `(batch, M, D)` expanded - in a basis of dimension `D` or `(batch, M, 1)` if using a :obj:`GaussBasisLayer` layer - with model argument :obj:`expand_distance=True` and the numeric distance between nodes. - - edge_indices (tf.RaggedTensor): Index list for edges of shape `(batch, M, 2)`. - - node_coordinates (tf.RaggedTensor): Node (atomic) coordinates of shape `(batch, None, 3)`. - - total_nodes(Tensor): Number of Nodes in graph if not same sized graphs of shape `(batch, )` . - - total_edges(Tensor): Number of Edges in graph if not same sized graphs of shape `(batch, )` . - - Outputs: - Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`. + Model inputs: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, coordinates, edge_indices, ...]` with `make_distance` and + with '...' indicating mask or ID tensors following the template below. + Note that you could also supply edge features with `make_distance` to False, which would make the input + :obj:`[nodes, edges, edge_indices, ...]` . + + %s + + Model outputs: + The standard output template: + + %s Args: - inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. input_tensor_type (str): Input type of graph tensor. Default is "padded". cast_disjoint_kwargs (dict): Dictionary of arguments for casting layer. input_embedding (dict): Deprecated in favour of input_node_embedding etc. @@ -121,14 +113,14 @@ def make_model(inputs: list = None, last_mlp (dict): Dictionary of layer arguments unpacked in last :obj:`MLP` layer before output or pooling. output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". use_output_mlp (bool): Whether to use the final output MLP. Possibility to skip final MLP. - output_to_tensor (bool): Whether to cast model output to :obj:`tf.Tensor`. + output_to_tensor (bool): Deprecated in favour of `output_tensor_type` . output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. Defines number of model outputs and activation. output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". Returns: - :obj:`ks.models.Model` + :obj:`keras.models.Model` """ # Make input model_inputs = [Input(**x) for x in inputs] @@ -141,7 +133,7 @@ def make_model(inputs: list = None, out = model_disjoint( [n, x, disjoint_indices, batch_id_node, count_nodes], - use_node_embedding=len(inputs[0]['shape']) < 2, + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, input_node_embedding=input_node_embedding, make_distance=make_distance, expand_distance=expand_distance, gauss_args=gauss_args, interaction_args=interaction_args, node_pooling_args=node_pooling_args, depth=depth, @@ -176,10 +168,13 @@ def set_scale(*args, **kwargs): return model +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) + + model_crystal_default = { "name": "Schnet", "inputs": [ - {"shape": (None,), "name": "node_number", "dtype": "float32"}, + {"shape": (None,), "name": "node_number", "dtype": "int64"}, {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32"}, {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, {"shape": (None, 3), "name": "edge_image", "dtype": "int64"}, @@ -200,7 +195,7 @@ def set_scale(*args, **kwargs): "last_mlp": {"use_bias": [True, True], "units": [128, 64], "activation": ["kgcnn>shifted_softplus", "kgcnn>shifted_softplus"]}, "output_embedding": "graph", - "output_to_tensor": True, # deprecated + "output_to_tensor": None, # deprecated "use_output_mlp": True, "output_tensor_type": "padded", "output_mlp": {"use_bias": [True, True], "units": [64, 1], @@ -212,7 +207,7 @@ def set_scale(*args, **kwargs): def make_crystal_model(inputs: list = None, input_tensor_type: str = None, cast_disjoint_kwargs: dict = None, - input_embedding: dict = None, + input_embedding: dict = None, # noqa input_node_embedding: dict = None, make_distance: bool = None, expand_distance: bool = None, @@ -221,11 +216,11 @@ def make_crystal_model(inputs: list = None, node_pooling_args: dict = None, depth: int = None, name: str = None, - verbose: int = None, + verbose: int = None, # noqa last_mlp: dict = None, output_embedding: str = None, use_output_mlp: bool = None, - output_to_tensor: bool = None, + output_to_tensor: bool = None, # noqa output_mlp: dict = None, output_scaling: dict = None, output_tensor_type: str = None, @@ -233,29 +228,20 @@ def make_crystal_model(inputs: list = None, r"""Make `SchNet `__ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.Schnet.model_crystal_default`. - Inputs: - list: `[node_attributes, edge_distance, edge_indices, edge_image, lattice, total_nodes, total_edges]` - or `[node_attributes, node_coordinates, edge_indices, edge_image, lattice, total_nodes, total_edges]` - if :obj:`make_distance=True` and :obj:`expand_distance=True` to compute edge distances from node coordinates - within the model. - - - node_attributes (Tensor): Node attributes of shape `(batch, N, F)` or `(batch, N)` - using an embedding layer. - - edge_distance (Tensor): Edge distance of shape `(batch, M, D)` expanded - in a basis of dimension `D` or `(batch, M, 1)` if using a :obj:`GaussBasisLayer` layer - with model argument :obj:`expand_distance=True` and the numeric distance between nodes. - - edge_indices (Tensor): Index list for edges of shape `(batch, M, 2)`. - - node_coordinates (Tensor): Node (atomic) coordinates of shape `(batch, None, 3)`. - - lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)`. - - edge_image (Tensor): Indices of the periodic image the sending node is located. - - total_nodes(Tensor): Number of Nodes in graph if not same sized graphs of shape `(batch, )` . - - total_edges(Tensor): Number of Edges in graph if not same sized graphs of shape `(batch, )` . - - Outputs: - Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`. + Model inputs: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, coordinates, edge_indices, image_translation, lattice, ...]` + with '...' indicating mask or ID tensors following the template below. + + %s + + Model outputs: + The standard output template: + + %s Args: - inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. input_tensor_type (str): Input type of graph tensor. Default is "padded". cast_disjoint_kwargs (dict): Dictionary of arguments for casting layer. input_embedding (dict): Deprecated in favour of input_node_embedding etc. @@ -273,7 +259,7 @@ def make_crystal_model(inputs: list = None, last_mlp (dict): Dictionary of layer arguments unpacked in last :obj:`MLP` layer before output or pooling. output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". use_output_mlp (bool): Whether to use the final output MLP. Possibility to skip final MLP. - output_to_tensor (bool): Whether to cast model output to :obj:`tf.Tensor`. + output_to_tensor (bool): Deprecated in favour of `output_tensor_type` . output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. Defines number of model outputs and activation. output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. @@ -290,12 +276,12 @@ def make_crystal_model(inputs: list = None, has_edges=(not make_distance), has_nodes=1 + int(make_distance), has_crystal_input=2) - n, x, disjoint_indices, edge_image, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs + n, x, djx, img, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs # Wrapp disjoint model out = model_disjoint_crystal( - [n, x, disjoint_indices, edge_image, lattice, batch_id_node, batch_id_edge, count_nodes], - use_node_embedding=len(inputs[0]['shape']) < 2, + [n, x, djx, img, lattice, batch_id_node, batch_id_edge, count_nodes], + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, input_node_embedding=input_node_embedding, make_distance=make_distance, expand_distance=expand_distance, gauss_args=gauss_args, interaction_args=interaction_args, node_pooling_args=node_pooling_args, depth=depth, last_mlp=last_mlp, @@ -328,3 +314,7 @@ def set_scale(*args, **kwargs): setattr(model, "set_scale", set_scale) return model + + +make_crystal_model.__doc__ = make_crystal_model.__doc__ % ( + template_cast_list_input.__doc__, template_cast_output.__doc__) diff --git a/training/hyper/hyper_esol.py b/training/hyper/hyper_esol.py index 8e28a278..437f486d 100644 --- a/training/hyper/hyper_esol.py +++ b/training/hyper/hyper_esol.py @@ -468,4 +468,66 @@ "kgcnn_version": "4.0.0" } }, + "PAiNN": { + "model": { + "class_name": "make_model", + "module_name": "kgcnn.literature.PAiNN", + "config": { + "name": "PAiNN", + "inputs": [ + {"shape": [None], "name": "node_number", "dtype": "int64", "ragged": True}, + {"shape": [None, 3], "name": "node_coordinates", "dtype": "float32", "ragged": True}, + {"shape": [None, 2], "name": "range_indices", "dtype": "int64", "ragged": True} + ], + "input_tensor_type": "ragged", + "cast_disjoint_kwargs": {}, + "input_node_embedding": {"input_dim": 95, "output_dim": 128}, + "bessel_basis": {"num_radial": 20, "cutoff": 5.0, "envelope_exponent": 5}, + "pooling_args": {"pooling_method": "scatter_sum"}, "conv_args": {"units": 128, "cutoff": None}, + "update_args": {"units": 128}, "depth": 3, "verbose": 10, + "output_embedding": "graph", + "output_mlp": {"use_bias": [True, True], "units": [128, 1], "activation": ["swish", "linear"]}, + } + }, + "training": { + "fit": { + "batch_size": 32, "epochs": 250, "validation_freq": 10, "verbose": 2, + "callbacks": [] + }, + "compile": { + "optimizer": { + "class_name": "Adam", "config": { + "learning_rate": { + "class_name": "kgcnn>LinearWarmupExponentialDecay", "config": { + "learning_rate": 0.001, "warmup_steps": 30.0, "decay_steps": 40000.0, + "decay_rate": 0.01 + } + }, "amsgrad": True, "use_ema": True + } + }, + "loss": "mean_absolute_error", + }, + "cross_validation": {"class_name": "KFold", + "config": {"n_splits": 5, "random_state": 42, "shuffle": True}}, + "scaler": {"class_name": "StandardLabelScaler", + "config": {"with_std": True, "with_mean": True, "copy": True}}, + }, + "data": { + "dataset": { + "class_name": "ESOLDataset", + "module_name": "kgcnn.data.datasets.ESOLDataset", + "config": {}, + "methods": [ + {"set_attributes": {"add_hydrogen": True}}, + {"map_list": {"method": "set_range", "max_distance": 3, "max_neighbours": 10000}} + ] + }, + "data_unit": "mol/L" + }, + "info": { + "postfix": "", + "postfix_file": "", + "kgcnn_version": "4.0.0" + } + }, } \ No newline at end of file diff --git a/training/train_graph.py b/training/train_graph.py index 9810ad81..c5565f6b 100644 --- a/training/train_graph.py +++ b/training/train_graph.py @@ -4,6 +4,7 @@ import argparse import time import kgcnn.training.scheduler # noqa +import kgcnn.training.schedule # noqa from datetime import timedelta import kgcnn.losses.losses import kgcnn.metrics.metrics @@ -22,8 +23,8 @@ # for training and model setup. parser = argparse.ArgumentParser(description='Train a GNN on a graph regression or classification task.') parser.add_argument("--hyper", required=False, help="Filepath to hyperparameter config file (.py or .json).", - default="hyper/hyper_mp_jdft2d.py") -parser.add_argument("--category", required=False, help="Graph model to train.", default="PAiNN.make_crystal_model") + default="hyper/hyper_esol.py") +parser.add_argument("--category", required=False, help="Graph model to train.", default="PAiNN") parser.add_argument("--model", required=False, help="Graph model to train.", default=None) parser.add_argument("--dataset", required=False, help="Name of the dataset.", default=None) parser.add_argument("--make", required=False, help="Name of the class for model.", default=None)