Skip to content

Commit

Permalink
update for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Nov 15, 2023
1 parent 1c678a9 commit a667d3a
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions kgcnn/literature/GAT/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,16 @@ def make_model(inputs: list = None,
r"""Make `GAT <https://arxiv.org/abs/1710.10903>`_ graph network via functional API.
Default parameters can be found in :obj:`kgcnn.literature.GAT.model_default`.
Inputs:
list: `[node_attributes, edge_attributes, edge_indices, total_nodes, total_edges]`
Model inputs:
Model uses the list template of inputs and standard output template.
The supported inputs are :obj:`[nodes, edges, edge_indices, ...]`
with '...' indicating mask or id tensors following the template below:
%s
- node_attributes (Tensor): Node attributes of shape `(batch, None, F)` or `(batch, None)`
using an embedding layer.
- edge_attributes (Tensor): Edge attributes of shape `(batch, None, F)` or `(batch, None)`
using an embedding layer.
- edge_indices (Tensor): Index list for edges of shape `(batch, None, 2)`.
- total_nodes(Tensor): Number of Nodes in graph if not same sized graphs of shape `(batch, )` .
- total_edges(Tensor): Number of Edges in graph if not same sized graphs of shape `(batch, )` .
Model outputs:
The standard output template:
%s
Outputs:
Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`.
Args:
inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition.
Expand Down Expand Up @@ -150,3 +147,6 @@ def set_scale(*args, **kwargs):
setattr(model, "set_scale", set_scale)

return model


make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__)

0 comments on commit a667d3a

Please sign in to comment.