diff --git a/kgcnn/layers/modules.py b/kgcnn/layers/modules.py index 556d835d..4d496927 100644 --- a/kgcnn/layers/modules.py +++ b/kgcnn/layers/modules.py @@ -60,3 +60,26 @@ def get_config(self): config = super(ExpandDims, self).get_config() config.update({"axis": self.axis}) return config + + +def Input( + shape=None, + batch_size=None, + dtype=None, + sparse=None, + batch_shape=None, + name=None, + tensor=None, + ragged=None + ): + + layer = ks.layers.InputLayer( + shape=shape, + batch_size=batch_size, + dtype=dtype, + sparse=sparse, + batch_shape=batch_shape, + name=name, + input_tensor=tensor, + ) + return layer.output \ No newline at end of file diff --git a/kgcnn/literature/DMPNN/_make.py b/kgcnn/literature/DMPNN/_make.py index 8823fa95..a2c29ff5 100644 --- a/kgcnn/literature/DMPNN/_make.py +++ b/kgcnn/literature/DMPNN/_make.py @@ -6,6 +6,7 @@ from kgcnn.layers.scale import get as get_scaler from kgcnn.models.utils import update_model_kwargs from keras_core.backend import backend as backend_to_use +from kgcnn.layers.modules import Input from ._model import model_disjoint # Keep track of model version from commit date in literature. @@ -132,7 +133,7 @@ def make_model(name: str = None, :obj:`keras.models.Model` """ # Make input - model_inputs = [ks.layers.Input(**x) for x in inputs] + model_inputs = [Input(**x) for x in inputs] if input_tensor_type in ["padded", "masked"]: batched_nodes, batched_edges, batched_indices, batched_reverse, total_nodes, total_edges = model_inputs[:6]