Skip to content

Commit

Permalink
update for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Nov 16, 2023
1 parent d28beb7 commit 934e718
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 18 deletions.
9 changes: 9 additions & 0 deletions kgcnn/layers/aggr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ def __init__(self, pooling_method: str = "scatter_sum", axis=0, **kwargs):
axis (int): Axis to aggregate. Default is 0.
"""
super(Aggregate, self).__init__(**kwargs)
# Shorthand notation
if pooling_method == "sum":
pooling_method = "scatter_sum"
if pooling_method == "max":
pooling_method = "scatter_max"
if pooling_method == "min":
pooling_method = "scatter_min"
if pooling_method == "mean":
pooling_method = "scatter_mean"
self.pooling_method = pooling_method
self.axis = axis
if axis != 0:
Expand Down
4 changes: 2 additions & 2 deletions kgcnn/literature/AttentiveFP/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ def make_model(inputs: list = None,
r"""Make `AttentiveFP <https://doi.org/10.1021/acs.jmedchem.9b00959>`__ graph network via functional API.
Default parameters can be found in :obj:`kgcnn.literature.AttentiveFP.model_default`.
Model inputs:
**Model inputs**:
Model uses the list template of inputs and standard output template.
The supported inputs are :obj:`[nodes, edges, edge_indices, ...]`
with '...' indicating mask or id tensors following the template below:
%s
Model outputs:
**Model outputs**:
The standard output template:
%s
Expand Down
4 changes: 2 additions & 2 deletions kgcnn/literature/DMPNN/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def make_model(name: str = None,
r"""Make `DMPNN <https://pubs.acs.org/doi/full/10.1021/acs.jcim.9b00237>`__ graph network via functional API.
Default parameters can be found in :obj:`kgcnn.literature.DMPNN.model_default`.
Model inputs:
**Model inputs**:
Model uses the list template of inputs and standard output template.
The supported inputs are :obj:`[nodes, edges, edge_indices, reverse_indices, (graph_state), ...]`
with '...' indicating mask or id tensors following the template below.
Expand All @@ -93,7 +93,7 @@ def make_model(name: str = None,
%s
Model outputs:
**Model outputs**:
The standard output template:
%s
Expand Down
4 changes: 2 additions & 2 deletions kgcnn/literature/GAT/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ def make_model(inputs: list = None,
r"""Make `GAT <https://arxiv.org/abs/1710.10903>`__ graph network via functional API.
Default parameters can be found in :obj:`kgcnn.literature.GAT.model_default`.
Model inputs:
**Model inputs**:
Model uses the list template of inputs and standard output template.
The supported inputs are :obj:`[nodes, edges, edge_indices, ...]`
with '...' indicating mask or ID tensors following the template below:
%s
Model outputs:
**Model outputs**:
The standard output template:
%s
Expand Down
4 changes: 2 additions & 2 deletions kgcnn/literature/GATv2/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ def make_model(inputs: list = None,
r"""Make `GATv2 <https://arxiv.org/abs/2105.14491>`__ graph network via functional API.
Default parameters can be found in :obj:`kgcnn.literature.GATv2.model_default`.
Model inputs:
**Model inputs**:
Model uses the list template of inputs and standard output template.
The supported inputs are :obj:`[nodes, edges, edge_indices, ...]`
with '...' indicating mask or ID tensors following the template below:
%s
Model outputs:
**Model outputs**:
The standard output template:
%s
Expand Down
4 changes: 2 additions & 2 deletions kgcnn/literature/GCN/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ def make_model(inputs: list = None,
r"""Make `GCN <https://arxiv.org/abs/1609.02907>`__ graph network via functional API.
Default parameters can be found in :obj:`kgcnn.literature.GCN.model_default`.
Model inputs:
**Model inputs**:
Model uses the list template of inputs and standard output template.
The supported inputs are :obj:`[nodes, edges, edge_indices, ...]`
with '...' indicating mask or ID tensors following the template below.
Edges are actually edge single weight values which are entries of the pre-scaled adjacency matrix.
%s
Model outputs:
**Model outputs**:
The standard output template:
%s
Expand Down
4 changes: 2 additions & 2 deletions kgcnn/literature/GIN/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ def make_model(inputs: list = None,
r"""Make `GIN <https://arxiv.org/abs/1810.00826>`__ graph network via functional API.
Default parameters can be found in :obj:`kgcnn.literature.GIN.model_default`.
Model inputs:
**Model inputs**:
Model uses the list template of inputs and standard output template.
The supported inputs are :obj:`[nodes, edge_indices, ...]`
with '...' indicating mask or ID tensors following the template below:
%s
Model outputs:
**Model outputs**:
The standard output template:
%s
Expand Down
4 changes: 2 additions & 2 deletions kgcnn/literature/GraphSAGE/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,15 @@ def make_model(inputs: list = None,
r"""Make `GraphSAGE <http://arxiv.org/abs/1706.02216>`__ graph network via functional API.
Default parameters can be found in :obj:`kgcnn.literature.GraphSAGE.model_default` .
Model inputs:
**Model inputs**:
Model uses the list template of inputs and standard output template.
The supported inputs are :obj:`[nodes, edges, edge_indices, ...]`
with '...' indicating mask or ID tensors following the template below.
Edges are actually edge single weight values which are entries of the pre-scaled adjacency matrix.
%s
Model outputs:
**Model outputs**:
The standard output template:
%s
Expand Down
4 changes: 2 additions & 2 deletions kgcnn/literature/PAiNN/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def make_crystal_model(inputs: list = None,
r"""Make `PAiNN <https://arxiv.org/pdf/2102.03150.pdf>`__ graph network via functional API.
Default parameters can be found in :obj:`kgcnn.literature.PAiNN.model_crystal_default`.
Model inputs:
**Model inputs**:
Model uses the list template of inputs and standard output template.
The supported inputs are :obj:`[nodes, coordinates, edge_indices, image_translation, lattice, ...]`
with '...' indicating mask or ID tensors following the template below.
Expand All @@ -238,7 +238,7 @@ def make_crystal_model(inputs: list = None,
%s
Model outputs:
**Model outputs**:
The standard output template:
%s
Expand Down
4 changes: 2 additions & 2 deletions kgcnn/literature/Schnet/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,14 @@ def make_crystal_model(inputs: list = None,
r"""Make `SchNet <https://arxiv.org/abs/1706.08566>`__ graph network via functional API.
Default parameters can be found in :obj:`kgcnn.literature.Schnet.model_crystal_default`.
Model inputs:
**Model inputs**:
Model uses the list template of inputs and standard output template.
The supported inputs are :obj:`[nodes, coordinates, edge_indices, image_translation, lattice, ...]`
with '...' indicating mask or ID tensors following the template below.
%s
Model outputs:
**Model outputs**:
The standard output template:
%s
Expand Down

0 comments on commit 934e718

Please sign in to comment.