import keras as ks
from ._model import model_disjoint
from kgcnn.models.utils import update_model_kwargs
from kgcnn.layers.scale import get as get_scaler
from kgcnn.layers.modules import Input
from kgcnn.models.casting import template_cast_output, template_cast_list_input
from keras.backend import backend as backend_to_use
from kgcnn.ops.activ import *  # noqa  # makes kgcnn's custom activation functions available

# Keep track of model version from commit date in literature.
# To be updated if model is changed in a significant way.
__model_version__ = "2023-10-12"

# Supported backends
__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"]
if backend_to_use() not in __kgcnn_model_backend_supported__:
    raise NotImplementedError("Backend '%s' for model 'rGIN' is not supported." % backend_to_use())

# Implementation of rGIN in `keras` from paper:
# Random Features Strengthen Graph Neural Networks
# Ryoma Sato, Makoto Yamada, Hisashi Kashima
# https://arxiv.org/abs/2002.03155

model_default = {
    "name": "rGIN",
    "inputs": [
        {"shape": (None,), "name": "node_attributes", "dtype": "float32", "ragged": True},
        {"shape": (None, 2), "name": "edge_indices", "dtype": "int64", "ragged": True}
    ],
    "input_tensor_type": "padded",
    "cast_disjoint_kwargs": {},
    "input_embedding": None,  # deprecated
    "input_node_embedding": {"input_dim": 95, "output_dim": 64},
    "gin_mlp": {"units": [64, 64], "use_bias": True, "activation": ["relu", "linear"],
                "use_normalization": True, "normalization_technique": "graph_batch"},
    "rgin_args": {"random_range": 100},
    "depth": 3, "dropout": 0.0, "verbose": 10,
    "last_mlp": {"use_bias": [True, True, True], "units": [64, 64, 64],
                 "activation": ["relu", "relu", "linear"]},
    "output_embedding": 'graph',
    "output_mlp": {"use_bias": True, "units": 1,
                   "activation": "softmax"},
    "output_to_tensor": None,  # deprecated
    "output_tensor_type": "padded",
    "output_scaling": None,
}
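
# A minimal customization sketch (hypothetical values, not part of the defaults): because
# `make_model` is wrapped with `update_model_kwargs`, only options that differ from
# `model_default` need to be passed, e.g. a linear single-output readout for regression:
#
#     model = make_model(
#         depth=4,
#         output_mlp={"use_bias": True, "units": 1, "activation": "linear"},
#     )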


@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"])
def make_model(inputs: list = None,
               input_tensor_type: str = None,
               cast_disjoint_kwargs: dict = None,
               input_embedding: dict = None,  # noqa
               input_node_embedding: dict = None,
               depth: int = None,
               rgin_args: dict = None,
               gin_mlp: dict = None,
               last_mlp: dict = None,
               dropout: float = None,
               name: str = None,  # noqa
               verbose: int = None,  # noqa
               output_embedding: str = None,
               output_to_tensor: bool = None,  # noqa
               output_mlp: dict = None,
               output_scaling: dict = None,
               output_tensor_type: str = None,
               ):
    r"""Make `rGIN <https://arxiv.org/abs/2002.03155>`__ graph network via functional API.
    Default parameters can be found in :obj:`kgcnn.literature.rGIN.model_default` .

    **Model inputs**:
    Model uses the list template of inputs and standard output template.

    %s

    **Model outputs**:
    The standard output template:

    %s

    Args:
        inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition.
        input_tensor_type (str): Input type of graph tensor. Default is "padded".
        cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers.
        input_embedding (dict): Deprecated in favour of input_node_embedding etc.
        input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers.
        depth (int): Number of graph embedding units or depth of the network.
        rgin_args (dict): Dictionary of layer arguments unpacked in :obj:`rGIN` convolutional layer.
        gin_mlp (dict): Dictionary of layer arguments unpacked in :obj:`MLP` for convolutional layer.
        last_mlp (dict): Dictionary of layer arguments unpacked in last :obj:`MLP` layer before output or pooling.
        dropout (float): Dropout to use.
        name (str): Name of the model.
        verbose (int): Level of print output.
        output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph".
        output_to_tensor (bool): Deprecated in favour of `output_tensor_type` .
        output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block.
            Defines number of model outputs and activation.
        output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None.
        output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded".

    Returns:
        :obj:`keras.models.Model`
    """
    # Make input
    model_inputs = [Input(**x) for x in inputs]

    # Cast the batched input tensors to a disjoint graph representation.
    disjoint_inputs = template_cast_list_input(
        model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs,
        mask_assignment=[0, 1],
        index_assignment=[None, 0]
    )

    n, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs

    # Wrapping disjoint model.
    out = model_disjoint(
        [n, disjoint_indices, batch_id_node, count_nodes],
        use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False,
        input_node_embedding=input_node_embedding,
        gin_mlp=gin_mlp,
        depth=depth,
        rgin_args=rgin_args,
        last_mlp=last_mlp,
        output_mlp=output_mlp,
        output_embedding=output_embedding,
        dropout=dropout
    )

    if output_scaling is not None:
        scaler = get_scaler(output_scaling["name"])(**output_scaling)
        out = scaler(out)

    # Cast the disjoint output back to the requested output tensor type.
    out = template_cast_output(
        [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges],
        output_embedding=output_embedding, output_tensor_type=output_tensor_type,
        input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs
    )

    model = ks.models.Model(inputs=model_inputs, outputs=out, name=name)
    model.__kgcnn_model_version__ = __model_version__

    if output_scaling is not None:
        def set_scale(*args, **kwargs):
            scaler.set_scale(*args, **kwargs)

        setattr(model, "set_scale", set_scale)
    return model


make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__)
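

# Example usage (a minimal sketch, assuming this module is exposed as `kgcnn.literature.rGIN`
# and one of the supported backends listed above is installed):
#
#     from kgcnn.literature.rGIN import make_model
#     model = make_model()  # defaults from `model_default` are filled in by `update_model_kwargs`
#     model.summary()
#
# The resulting model then expects the tensors listed in `model_default["inputs"]`
# (node attributes and edge indices) in exactly that order.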