Skip to content

Commit

Permalink
update for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Nov 14, 2023
1 parent 05f7c9d commit 34db772
Show file tree
Hide file tree
Showing 11 changed files with 150 additions and 83 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ original implementations (with proper licencing).
* **[GAT](kgcnn/literature/GAT)**: [Graph Attention Networks](https://arxiv.org/abs/1710.10903) by Veličković et al. (2018)
* **[GraphSAGE](kgcnn/literature/GraphSAGE)**: [Inductive Representation Learning on Large Graphs](http://arxiv.org/abs/1706.02216) by Hamilton et al. (2017)
* **[GIN](kgcnn/literature/GIN)**: [How Powerful are Graph Neural Networks?](https://arxiv.org/abs/1810.00826) by Xu et al. (2019)
* **[GNNExplainer](kgcnn/literature/GNNExplain)**: [GNNExplainer: Generating Explanations for Graph Neural Networks](https://arxiv.org/abs/1903.03894) by Ying et al. (2019)

<details>
<summary> ... and many more <b>(click to expand)</b>.</summary>
Expand All @@ -168,6 +169,7 @@ original implementations (with proper licencing).
* **[AttentiveFP](kgcnn/literature/AttentiveFP)**: [Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism](https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959) by Xiong et al. (2019)
* **[PAiNN](kgcnn/literature/PAiNN)**: [Equivariant message passing for the prediction of tensorial properties and molecular spectra](https://arxiv.org/pdf/2102.03150.pdf) by Schütt et al. (2020)

To be completed ...
</details>


Expand Down
21 changes: 15 additions & 6 deletions kgcnn/layers/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,33 +386,42 @@ def call(self, inputs: list, **kwargs):
r"""Changes node or edge tensors into a Pytorch Geometric (PyG) compatible tensor format.
Args:
inputs (list): List of `[(target), attr, graph_id_attr, attr_id, attr_counts]` ,
inputs (list): List of `[attr, graph_id_attr, (attr_id), attr_counts]` ,
- target (Tensor, Optional): Target tensor of shape equal to desired output of
shape `(batch, N, F, ...)` .
Only first dimension is required for making new tensor e.g. of shape `(batch,)` .
- attr (Tensor): Features are represented by a keras tensor of shape `([N], F, ...)` ,
where N denotes the number of nodes or edges.
- graph_id_attr (Tensor): ID tensor of graph assignment in disjoint graph of shape `([N], )` .
- attr_id (Tensor): The ID-tensor to assign each node to its respective graph of shape `([N], )` .
- attr_id (Tensor, optional): The ID-tensor to assign each node to its respective graph
of shape `([N], )` . For padded disjoint graphs this is required.
- attr_counts (Tensor): Tensor of lengths for each graph of shape `(batch, )` .
Returns:
Tensor: Batched output tensor of node or edge attributes of shape `(batch, N, F, ...)` .
"""
attr, graph_id_attr, attr_id, attr_len = inputs
if len(inputs) == 4:
attr, graph_id_attr, attr_id, attr_len = inputs
else:
attr, graph_id_attr, attr_len = inputs
attr_id = None

if self.static_output_shape is not None:
target_shape = (ops.shape(attr_len)[0], self.static_output_shape[0])
else:
target_shape = (ops.shape(attr_len)[0], ops.amax(attr_len))

if not self.padded_disjoint:
if attr_id is None:
attr_id = ops.arange(0, ops.shape(graph_id_attr)[0], dtype=graph_id_attr.dtype)
attr_splits = _pad_left(attr_len[1:]-attr_len[:1])
attr_id = attr_id - repeat_static_length(attr_splits[:1], attr_len, ops.shape(graph_id_attr)[0])
output_shape = [target_shape[0]*target_shape[1]] + list(ops.shape(attr)[1:])
indices = graph_id_attr*ops.convert_to_tensor(target_shape[1], dtype=graph_id_attr.dtype) + ops.cast(
attr_id, dtype=graph_id_attr.dtype)
out = scatter_reduce_sum(indices, attr, output_shape)
out = ops.reshape(out, list(target_shape[:2]) + list(ops.shape(attr)[1:]))
else:
if attr_id is None:
raise ValueError("Require sub-graph IDs in addition to batch IDs for padded disjoint graphs.")
output_shape = [(target_shape[0]+1)*target_shape[1]] + list(ops.shape(attr)[1:])
indices = graph_id_attr * ops.convert_to_tensor(target_shape[1], dtype=graph_id_attr.dtype) + ops.cast(
attr_id, dtype=graph_id_attr.dtype)
Expand Down
44 changes: 13 additions & 31 deletions kgcnn/literature/DMPNN/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
CastRaggedIndicesToDisjoint, CastDisjointToRaggedAttributes)
from kgcnn.layers.scale import get as get_scaler
from kgcnn.models.utils import update_model_kwargs
from kgcnn.models.casting import template_cast_output
from kgcnn.models.casting import template_cast_output, template_cast_list_input
from keras_core.backend import backend as backend_to_use
from kgcnn.layers.modules import Input
from ._model import model_disjoint
Expand Down Expand Up @@ -138,40 +138,22 @@ def make_model(name: str = None,
# Make input
model_inputs = [Input(**x) for x in inputs]

if input_tensor_type in ["padded", "masked"]:
if use_graph_state:
batched_nodes, batched_edges, graph_state, batched_indices, batched_reverse, total_nodes, total_edges = model_inputs[:7]
gs = CastBatchedGraphStateToDisjoint(**cast_disjoint_kwargs)(graph_state)
else:
batched_nodes, batched_edges, batched_indices, batched_reverse, total_nodes, total_edges = model_inputs[:6]
gs = None
n, edi, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = CastBatchedIndicesToDisjoint(
**cast_disjoint_kwargs)([batched_nodes, batched_indices, total_nodes, total_edges])
ed, _, _, _ = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([batched_edges, total_edges])
_, ed_pairs, _, _, _, _, _, _ = CastBatchedIndicesToDisjoint(
**cast_disjoint_kwargs)([batched_indices, batched_reverse, count_edges, count_edges])
elif input_tensor_type in ["ragged", "jagged"]:
if use_graph_state:
batched_nodes, batched_edges, graph_state, batched_indices, batched_reverse = model_inputs
gs = graph_state
else:
batched_nodes, batched_edges, batched_indices, batched_reverse = model_inputs
gs = None
n, edi, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = CastRaggedIndicesToDisjoint(
**cast_disjoint_kwargs)([batched_nodes, batched_indices])
ed, _, _, _ = CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(batched_edges)
_, ed_pairs, _, _, _, _, _, _ = CastRaggedIndicesToDisjoint(
**cast_disjoint_kwargs)([batched_indices, batched_reverse])
di = template_cast_list_input(
model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs,
has_nodes=True, has_edges=True, has_graph_state=use_graph_state,
has_angle_indices=True, # Treat reverse indices as edge indices
has_edge_indices=True
)

if use_graph_state:
n, ed, edi, e_pairs, gs, batch_id_node, batch_id_edge, _, node_id, edge_id, _, count_nodes, count_edges, _ = di
else:
if use_graph_state:
n, ed, gs, edi, ed_pairs, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = model_inputs
else:
n, ed, edi, ed_pairs, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = model_inputs
gs = None
n, ed, edi, e_pairs, batch_id_node, batch_id_edge, _, node_id, edge_id, _, count_nodes, count_edges, _ = di
gs = None

# Wrapping disjoint model.
out = model_disjoint(
[n, ed, edi, batch_id_node, ed_pairs, count_nodes, gs],
[n, ed, edi, batch_id_node, e_pairs, count_nodes, gs],
use_node_embedding=len(inputs[0]['shape']) < 2, use_edge_embedding=len(inputs[1]['shape']) < 2,
use_graph_embedding=len(inputs[7]["shape"]) < 1 if use_graph_state else False,
input_node_embedding=input_node_embedding,
Expand Down
6 changes: 3 additions & 3 deletions kgcnn/literature/GAT/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from kgcnn.layers.scale import get as get_scaler
from kgcnn.layers.modules import Input
from keras_core.backend import backend as backend_to_use
from kgcnn.models.casting import template_cast_input, template_cast_output
from kgcnn.models.casting import template_cast_list_input, template_cast_output
from kgcnn.ops.activ import *

# Keep track of model version from commit date in literature.
Expand Down Expand Up @@ -114,10 +114,10 @@ def make_model(inputs: list = None,
# Make input
model_inputs = [Input(**x) for x in inputs]

disjoint_model_inputs = template_cast_input(
di_inputs = template_cast_list_input(
model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs)

n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_model_inputs
n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs

# Wrapping disjoint model.
out = model_disjoint(
Expand Down
4 changes: 2 additions & 2 deletions kgcnn/literature/GATv2/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kgcnn.models.utils import update_model_kwargs
from kgcnn.layers.scale import get as get_scaler
from kgcnn.layers.modules import Input
from kgcnn.models.casting import template_cast_input, template_cast_output
from kgcnn.models.casting import template_cast_list_input, template_cast_output
from keras_core.backend import backend as backend_to_use
from kgcnn.ops.activ import *

Expand Down Expand Up @@ -116,7 +116,7 @@ def make_model(inputs: list = None,
# Make input
model_inputs = [Input(**x) for x in inputs]

disjoint_model_inputs = template_cast_input(
disjoint_model_inputs = template_cast_list_input(
model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs)

n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_model_inputs
Expand Down
4 changes: 2 additions & 2 deletions kgcnn/literature/GCN/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ._model import model_disjoint
from kgcnn.layers.modules import Input
from kgcnn.models.utils import update_model_kwargs
from kgcnn.models.casting import template_cast_output, template_cast_input
from kgcnn.models.casting import template_cast_output, template_cast_list_input
from keras_core.backend import backend as backend_to_use

# from keras_core.layers import Activation
Expand Down Expand Up @@ -118,7 +118,7 @@ def make_model(inputs: list = None,
# Make input
model_inputs = [Input(**x) for x in inputs]

disjoint_model_inputs = template_cast_input(
disjoint_model_inputs = template_cast_list_input(
model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs)

n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_model_inputs
Expand Down
10 changes: 6 additions & 4 deletions kgcnn/literature/GIN/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from kgcnn.models.utils import update_model_kwargs
from kgcnn.layers.scale import get as get_scaler
from kgcnn.layers.modules import Input
from kgcnn.models.casting import template_cast_output, template_cast_input
from kgcnn.models.casting import template_cast_output, template_cast_list_input
from keras_core.backend import backend as backend_to_use
from kgcnn.ops.activ import *

Expand Down Expand Up @@ -111,8 +111,10 @@ def make_model(inputs: list = None,
# Make input
model_inputs = [Input(**x) for x in inputs]

disjoint_inputs = template_cast_input(
model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, has_edges=False)
disjoint_inputs = template_cast_list_input(
model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs,
has_edges=False
)

n, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs

Expand Down Expand Up @@ -237,7 +239,7 @@ def make_model_edge(inputs: list = None,
# Make input
model_inputs = [Input(**x) for x in inputs]

disjoint_inputs = template_cast_input(
disjoint_inputs = template_cast_list_input(
model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs)

n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs
Expand Down
4 changes: 2 additions & 2 deletions kgcnn/literature/GraphSAGE/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from kgcnn.layers.modules import Input
from kgcnn.models.utils import update_model_kwargs
from kgcnn.layers.scale import get as get_scaler
from kgcnn.models.casting import template_cast_output, template_cast_input
from kgcnn.models.casting import template_cast_output, template_cast_list_input
from keras_core.backend import backend as backend_to_use

# Keep track of model version from commit date in literature.
Expand Down Expand Up @@ -124,7 +124,7 @@ def make_model(inputs: list = None,
# Make input
model_inputs = [Input(**x) for x in inputs]

disjoint_model_inputs = template_cast_input(
disjoint_model_inputs = template_cast_list_input(
model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs)

n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_model_inputs
Expand Down
10 changes: 5 additions & 5 deletions kgcnn/literature/PAiNN/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from kgcnn.models.utils import update_model_kwargs
from keras_core.backend import backend as backend_to_use
from kgcnn.layers.scale import get as get_scaler
from kgcnn.models.casting import template_cast_output, template_cast_input
from kgcnn.models.casting import template_cast_output, template_cast_list_input
from ._model import model_disjoint, model_disjoint_crystal

# To be updated if model is changed in a significant way.
Expand Down Expand Up @@ -117,10 +117,10 @@ def make_model(inputs: list = None,
# Make input
model_inputs = [Input(**x) for x in inputs]

disjoint_inputs = template_cast_input(model_inputs, input_tensor_type=input_tensor_type,
cast_disjoint_kwargs=cast_disjoint_kwargs,
has_edges=True,
has_coordinates_as_edges=True)
disjoint_inputs = template_cast_list_input(model_inputs, input_tensor_type=input_tensor_type,
cast_disjoint_kwargs=cast_disjoint_kwargs,
has_nodes=2+int(has_equivariant_input),
has_edges=False)

if not has_equivariant_input:
z, x, edi, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs
Expand Down
13 changes: 7 additions & 6 deletions kgcnn/literature/Schnet/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from kgcnn.layers.scale import get as get_scaler
from ._model import model_disjoint, model_disjoint_crystal
from kgcnn.layers.modules import Input
from kgcnn.models.casting import template_cast_output, template_cast_input
from kgcnn.models.casting import template_cast_output, template_cast_list_input
from kgcnn.models.utils import update_model_kwargs
from keras_core.backend import backend as backend_to_use

Expand All @@ -29,7 +29,7 @@
model_default = {
"name": "Schnet",
"inputs": [
{"shape": (None,), "name": "node_attributes", "dtype": "float32"},
{"shape": (None,), "name": "node_number", "dtype": "int64"},
{"shape": (None, 3), "name": "node_coordinates", "dtype": "float32"},
{"shape": (None, 2), "name": "edge_indices", "dtype": "int64"},
{"shape": (), "name": "total_nodes", "dtype": "int64"},
Expand Down Expand Up @@ -133,15 +133,16 @@ def make_model(inputs: list = None,
# Make input
model_inputs = [Input(**x) for x in inputs]

disjoint_inputs = template_cast_input(model_inputs, input_tensor_type=input_tensor_type,
cast_disjoint_kwargs=cast_disjoint_kwargs,
has_edges=(not make_distance), has_nodes=1 + int(make_distance))
disjoint_inputs = template_cast_list_input(model_inputs, input_tensor_type=input_tensor_type,
cast_disjoint_kwargs=cast_disjoint_kwargs,
has_edges=(not make_distance), has_nodes=1 + int(make_distance))

n, x, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs

out = model_disjoint(
[n, x, disjoint_indices, batch_id_node, count_nodes],
use_node_embedding=len(inputs[0]['shape']) < 2, input_node_embedding=input_node_embedding,
use_node_embedding=len(inputs[0]['shape']) < 2,
input_node_embedding=input_node_embedding,
make_distance=make_distance, expand_distance=expand_distance, gauss_args=gauss_args,
interaction_args=interaction_args, node_pooling_args=node_pooling_args, depth=depth,
last_mlp=last_mlp, output_embedding=output_embedding, use_output_mlp=use_output_mlp,
Expand Down
Loading

0 comments on commit 34db772

Please sign in to comment.