Skip to content

Commit

Permalink
update for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Nov 15, 2023
1 parent f8de6d5 commit 1c678a9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
14 changes: 14 additions & 0 deletions kgcnn/literature/AttentiveFP/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@ def make_model(inputs: list = None,
r"""Make `AttentiveFP <https://doi.org/10.1021/acs.jmedchem.9b00959>`_ graph network via functional API.
Default parameters can be found in :obj:`kgcnn.literature.AttentiveFP.model_default`.
Model inputs:
Model uses the list template of inputs and standard output template.
The supported inputs are :obj:`[nodes, edges, edge_indices, ...]`
with '...' indicating mask or id tensors following the template below:
%s
Model outputs:
The standard output template:
%s
Args:
inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition.
cast_disjoint_kwargs (dict): Dictionary of arguments for :obj:`CastBatchedIndicesToDisjoint` .
Expand Down Expand Up @@ -141,3 +152,6 @@ def set_scale(*args, **kwargs):
setattr(model, "set_scale", set_scale)

return model


make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__)
4 changes: 3 additions & 1 deletion kgcnn/models/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

def template_cast_output(model_outputs,
output_embedding, output_tensor_type, input_tensor_type, cast_disjoint_kwargs):
"""TODO"""

out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = model_outputs

# Output embedding choice
Expand Down Expand Up @@ -55,7 +57,7 @@ def template_cast_list_input(model_inputs,
has_graph_state: Union[int, bool] = False,
has_crystal_input: Union[int, bool] = False,
return_sub_id: bool = True):

"""TODO"""
standard_inputs = [x for x in model_inputs]

batched_nodes = []
Expand Down

0 comments on commit 1c678a9

Please sign in to comment.