From 34db772c87ebdcea5d709a62405cf9ea63772d12 Mon Sep 17 00:00:00 2001 From: PatReis Date: Tue, 14 Nov 2023 12:28:52 +0100 Subject: [PATCH] update for keras 3.0 --- README.md | 2 + kgcnn/layers/casting.py | 21 +++-- kgcnn/literature/DMPNN/_make.py | 44 ++++------- kgcnn/literature/GAT/_make.py | 6 +- kgcnn/literature/GATv2/_make.py | 4 +- kgcnn/literature/GCN/_make.py | 4 +- kgcnn/literature/GIN/_make.py | 10 ++- kgcnn/literature/GraphSAGE/_make.py | 4 +- kgcnn/literature/PAiNN/_make.py | 10 +-- kgcnn/literature/Schnet/_make.py | 13 ++-- kgcnn/models/casting.py | 115 ++++++++++++++++++++++------ 11 files changed, 150 insertions(+), 83 deletions(-) diff --git a/README.md b/README.md index f6d95668..c8e82558 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,7 @@ original implementations (with proper licencing). * **[GAT](kgcnn/literature/GAT)**: [Graph Attention Networks](https://arxiv.org/abs/1710.10903) by Veličković et al. (2018) * **[GraphSAGE](kgcnn/literature/GraphSAGE)**: [Inductive Representation Learning on Large Graphs](http://arxiv.org/abs/1706.02216) by Hamilton et al. (2017) * **[GIN](kgcnn/literature/GIN)**: [How Powerful are Graph Neural Networks?](https://arxiv.org/abs/1810.00826) by Xu et al. (2019) +* **[GNNExplainer](kgcnn/literature/GNNExplain)**: [GNNExplainer: Generating Explanations for Graph Neural Networks](https://arxiv.org/abs/1903.03894) by Ying et al. (2019)
... and many more (click to expand). @@ -168,6 +169,7 @@ original implementations (with proper licencing). * **[AttentiveFP](kgcnn/literature/AttentiveFP)**: [Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism](https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959) by Xiong et al. (2019) * **[PAiNN](kgcnn/literature/PAiNN)**: [Equivariant message passing for the prediction of tensorial properties and molecular spectra](https://arxiv.org/pdf/2102.03150.pdf) by Schütt et al. (2020) +To be completed ...
diff --git a/kgcnn/layers/casting.py b/kgcnn/layers/casting.py index 03a20ed6..325d9026 100644 --- a/kgcnn/layers/casting.py +++ b/kgcnn/layers/casting.py @@ -386,33 +386,42 @@ def call(self, inputs: list, **kwargs): r"""Changes node or edge tensors into a Pytorch Geometric (PyG) compatible tensor format. Args: - inputs (list): List of `[(target), attr, graph_id_attr, attr_id, attr_counts]` , + inputs (list): List of `[attr, graph_id_attr, (attr_id), attr_counts]` , - - target (Tensor, Optional): Target tensor of shape equal to desired output of - shape `(batch, N, F, ...)` . - Only first dimension is required for making new tensor e.g. of shape `(batch,)` . - attr (Tensor): Features are represented by a keras tensor of shape `([N], F, ...)` , where N denotes the number of nodes or edges. - graph_id_attr (Tensor): ID tensor of graph assignment in disjoint graph of shape `([N], )` . - - attr_id (Tensor): The ID-tensor to assign each node to its respective graph of shape `([N], )` . + - attr_id (Tensor, optional): The ID-tensor to assign each node to its respective graph + of shape `([N], )` . For padded disjoint graphs this is required. - attr_counts (Tensor): Tensor of lengths for each graph of shape `(batch, )` . Returns: Tensor: Batched output tensor of node or edge attributes of shape `(batch, N, F, ...)` . """ - attr, graph_id_attr, attr_id, attr_len = inputs + if len(inputs) == 4: + attr, graph_id_attr, attr_id, attr_len = inputs + else: + attr, graph_id_attr, attr_len = inputs + attr_id = None + if self.static_output_shape is not None: target_shape = (ops.shape(attr_len)[0], self.static_output_shape[0]) else: target_shape = (ops.shape(attr_len)[0], ops.amax(attr_len)) if not self.padded_disjoint: + if attr_id is None: + attr_id = ops.arange(0, ops.shape(graph_id_attr)[0], dtype=graph_id_attr.dtype) + attr_splits = _pad_left(attr_len[1:]-attr_len[:1]) + attr_id = attr_id - repeat_static_length(attr_splits[:1], attr_len, ops.shape(graph_id_attr)[0]) output_shape = [target_shape[0]*target_shape[1]] + list(ops.shape(attr)[1:]) indices = graph_id_attr*ops.convert_to_tensor(target_shape[1], dtype=graph_id_attr.dtype) + ops.cast( attr_id, dtype=graph_id_attr.dtype) out = scatter_reduce_sum(indices, attr, output_shape) out = ops.reshape(out, list(target_shape[:2]) + list(ops.shape(attr)[1:])) else: + if attr_id is None: + raise ValueError("Require sub-graph IDs in addition to batch IDs for padded disjoint graphs.") output_shape = [(target_shape[0]+1)*target_shape[1]] + list(ops.shape(attr)[1:]) indices = graph_id_attr * ops.convert_to_tensor(target_shape[1], dtype=graph_id_attr.dtype) + ops.cast( attr_id, dtype=graph_id_attr.dtype) diff --git a/kgcnn/literature/DMPNN/_make.py b/kgcnn/literature/DMPNN/_make.py index 0c27b76b..8b2bc483 100644 --- a/kgcnn/literature/DMPNN/_make.py +++ b/kgcnn/literature/DMPNN/_make.py @@ -5,7 +5,7 @@ CastRaggedIndicesToDisjoint, CastDisjointToRaggedAttributes) from kgcnn.layers.scale import get as get_scaler from kgcnn.models.utils import update_model_kwargs -from kgcnn.models.casting import template_cast_output +from kgcnn.models.casting import template_cast_output, template_cast_list_input from keras_core.backend import backend as backend_to_use from kgcnn.layers.modules import Input from ._model import model_disjoint @@ -138,40 +138,22 @@ def make_model(name: str = None, # Make input model_inputs = [Input(**x) for x in inputs] - if input_tensor_type in ["padded", "masked"]: - if use_graph_state: - batched_nodes, batched_edges, graph_state, batched_indices, batched_reverse, total_nodes, total_edges = model_inputs[:7] - gs = CastBatchedGraphStateToDisjoint(**cast_disjoint_kwargs)(graph_state) - else: - batched_nodes, batched_edges, batched_indices, batched_reverse, total_nodes, total_edges = model_inputs[:6] - gs = None - n, edi, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = CastBatchedIndicesToDisjoint( - **cast_disjoint_kwargs)([batched_nodes, batched_indices, total_nodes, total_edges]) - ed, _, _, _ = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([batched_edges, total_edges]) - _, ed_pairs, _, _, _, _, _, _ = CastBatchedIndicesToDisjoint( - **cast_disjoint_kwargs)([batched_indices, batched_reverse, count_edges, count_edges]) - elif input_tensor_type in ["ragged", "jagged"]: - if use_graph_state: - batched_nodes, batched_edges, graph_state, batched_indices, batched_reverse = model_inputs - gs = graph_state - else: - batched_nodes, batched_edges, batched_indices, batched_reverse = model_inputs - gs = None - n, edi, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = CastRaggedIndicesToDisjoint( - **cast_disjoint_kwargs)([batched_nodes, batched_indices]) - ed, _, _, _ = CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(batched_edges) - _, ed_pairs, _, _, _, _, _, _ = CastRaggedIndicesToDisjoint( - **cast_disjoint_kwargs)([batched_indices, batched_reverse]) + di = template_cast_list_input( + model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, + has_nodes=True, has_edges=True, has_graph_state=use_graph_state, + has_angle_indices=True, # Treat reverse indices as edge indices + has_edge_indices=True + ) + + if use_graph_state: + n, ed, edi, e_pairs, gs, batch_id_node, batch_id_edge, _, node_id, edge_id, _, count_nodes, count_edges, _ = di else: - if use_graph_state: - n, ed, gs, edi, ed_pairs, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = model_inputs - else: - n, ed, edi, ed_pairs, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = model_inputs - gs = None + n, ed, edi, e_pairs, batch_id_node, batch_id_edge, _, node_id, edge_id, _, count_nodes, count_edges, _ = di + gs = None # Wrapping disjoint model. out = model_disjoint( - [n, ed, edi, batch_id_node, ed_pairs, count_nodes, gs], + [n, ed, edi, batch_id_node, e_pairs, count_nodes, gs], use_node_embedding=len(inputs[0]['shape']) < 2, use_edge_embedding=len(inputs[1]['shape']) < 2, use_graph_embedding=len(inputs[7]["shape"]) < 1 if use_graph_state else False, input_node_embedding=input_node_embedding, diff --git a/kgcnn/literature/GAT/_make.py b/kgcnn/literature/GAT/_make.py index c0c6f443..da354fc3 100644 --- a/kgcnn/literature/GAT/_make.py +++ b/kgcnn/literature/GAT/_make.py @@ -3,7 +3,7 @@ from kgcnn.layers.scale import get as get_scaler from kgcnn.layers.modules import Input from keras_core.backend import backend as backend_to_use -from kgcnn.models.casting import template_cast_input, template_cast_output +from kgcnn.models.casting import template_cast_list_input, template_cast_output from kgcnn.ops.activ import * # Keep track of model version from commit date in literature. @@ -114,10 +114,10 @@ def make_model(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_model_inputs = template_cast_input( + di_inputs = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) - n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_model_inputs + n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs # Wrapping disjoint model. out = model_disjoint( diff --git a/kgcnn/literature/GATv2/_make.py b/kgcnn/literature/GATv2/_make.py index 23a48f91..9fa81149 100644 --- a/kgcnn/literature/GATv2/_make.py +++ b/kgcnn/literature/GATv2/_make.py @@ -2,7 +2,7 @@ from kgcnn.models.utils import update_model_kwargs from kgcnn.layers.scale import get as get_scaler from kgcnn.layers.modules import Input -from kgcnn.models.casting import template_cast_input, template_cast_output +from kgcnn.models.casting import template_cast_list_input, template_cast_output from keras_core.backend import backend as backend_to_use from kgcnn.ops.activ import * @@ -116,7 +116,7 @@ def make_model(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_model_inputs = template_cast_input( + disjoint_model_inputs = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_model_inputs diff --git a/kgcnn/literature/GCN/_make.py b/kgcnn/literature/GCN/_make.py index ff059d01..3f39d7fb 100644 --- a/kgcnn/literature/GCN/_make.py +++ b/kgcnn/literature/GCN/_make.py @@ -7,7 +7,7 @@ from ._model import model_disjoint from kgcnn.layers.modules import Input from kgcnn.models.utils import update_model_kwargs -from kgcnn.models.casting import template_cast_output, template_cast_input +from kgcnn.models.casting import template_cast_output, template_cast_list_input from keras_core.backend import backend as backend_to_use # from keras_core.layers import Activation @@ -118,7 +118,7 @@ def make_model(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_model_inputs = template_cast_input( + disjoint_model_inputs = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_model_inputs diff --git a/kgcnn/literature/GIN/_make.py b/kgcnn/literature/GIN/_make.py index 63fa9bb0..fabcf1a1 100644 --- a/kgcnn/literature/GIN/_make.py +++ b/kgcnn/literature/GIN/_make.py @@ -7,7 +7,7 @@ from kgcnn.models.utils import update_model_kwargs from kgcnn.layers.scale import get as get_scaler from kgcnn.layers.modules import Input -from kgcnn.models.casting import template_cast_output, template_cast_input +from kgcnn.models.casting import template_cast_output, template_cast_list_input from keras_core.backend import backend as backend_to_use from kgcnn.ops.activ import * @@ -111,8 +111,10 @@ def make_model(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_inputs = template_cast_input( - model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, has_edges=False) + disjoint_inputs = template_cast_list_input( + model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, + has_edges=False + ) n, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs @@ -237,7 +239,7 @@ def make_model_edge(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_inputs = template_cast_input( + disjoint_inputs = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs diff --git a/kgcnn/literature/GraphSAGE/_make.py b/kgcnn/literature/GraphSAGE/_make.py index 71c873c4..722f1bea 100644 --- a/kgcnn/literature/GraphSAGE/_make.py +++ b/kgcnn/literature/GraphSAGE/_make.py @@ -7,7 +7,7 @@ from kgcnn.layers.modules import Input from kgcnn.models.utils import update_model_kwargs from kgcnn.layers.scale import get as get_scaler -from kgcnn.models.casting import template_cast_output, template_cast_input +from kgcnn.models.casting import template_cast_output, template_cast_list_input from keras_core.backend import backend as backend_to_use # Keep track of model version from commit date in literature. @@ -124,7 +124,7 @@ def make_model(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_model_inputs = template_cast_input( + disjoint_model_inputs = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_model_inputs diff --git a/kgcnn/literature/PAiNN/_make.py b/kgcnn/literature/PAiNN/_make.py index 89537f14..1c7f414f 100644 --- a/kgcnn/literature/PAiNN/_make.py +++ b/kgcnn/literature/PAiNN/_make.py @@ -4,7 +4,7 @@ from kgcnn.models.utils import update_model_kwargs from keras_core.backend import backend as backend_to_use from kgcnn.layers.scale import get as get_scaler -from kgcnn.models.casting import template_cast_output, template_cast_input +from kgcnn.models.casting import template_cast_output, template_cast_list_input from ._model import model_disjoint, model_disjoint_crystal # To be updated if model is changed in a significant way. @@ -117,10 +117,10 @@ def make_model(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_inputs = template_cast_input(model_inputs, input_tensor_type=input_tensor_type, - cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=True, - has_coordinates_as_edges=True) + disjoint_inputs = template_cast_list_input(model_inputs, input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + has_nodes=2+int(has_equivariant_input), + has_edges=False) if not has_equivariant_input: z, x, edi, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs diff --git a/kgcnn/literature/Schnet/_make.py b/kgcnn/literature/Schnet/_make.py index 5aceb10c..6d6d1e38 100644 --- a/kgcnn/literature/Schnet/_make.py +++ b/kgcnn/literature/Schnet/_make.py @@ -6,7 +6,7 @@ from kgcnn.layers.scale import get as get_scaler from ._model import model_disjoint, model_disjoint_crystal from kgcnn.layers.modules import Input -from kgcnn.models.casting import template_cast_output, template_cast_input +from kgcnn.models.casting import template_cast_output, template_cast_list_input from kgcnn.models.utils import update_model_kwargs from keras_core.backend import backend as backend_to_use @@ -29,7 +29,7 @@ model_default = { "name": "Schnet", "inputs": [ - {"shape": (None,), "name": "node_attributes", "dtype": "float32"}, + {"shape": (None,), "name": "node_number", "dtype": "int64"}, {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32"}, {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, {"shape": (), "name": "total_nodes", "dtype": "int64"}, @@ -133,15 +133,16 @@ def make_model(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_inputs = template_cast_input(model_inputs, input_tensor_type=input_tensor_type, - cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=(not make_distance), has_nodes=1 + int(make_distance)) + disjoint_inputs = template_cast_list_input(model_inputs, input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + has_edges=(not make_distance), has_nodes=1 + int(make_distance)) n, x, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs out = model_disjoint( [n, x, disjoint_indices, batch_id_node, count_nodes], - use_node_embedding=len(inputs[0]['shape']) < 2, input_node_embedding=input_node_embedding, + use_node_embedding=len(inputs[0]['shape']) < 2, + input_node_embedding=input_node_embedding, make_distance=make_distance, expand_distance=expand_distance, gauss_args=gauss_args, interaction_args=interaction_args, node_pooling_args=node_pooling_args, depth=depth, last_mlp=last_mlp, output_embedding=output_embedding, use_output_mlp=use_output_mlp, diff --git a/kgcnn/models/casting.py b/kgcnn/models/casting.py index 538aab9d..6824f6f9 100644 --- a/kgcnn/models/casting.py +++ b/kgcnn/models/casting.py @@ -9,7 +9,6 @@ def template_cast_output(model_outputs, output_embedding, output_tensor_type, input_tensor_type, cast_disjoint_kwargs): - out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = model_outputs # Output embedding choice @@ -45,75 +44,147 @@ def template_cast_output(model_outputs, return out -def template_cast_input(model_inputs, - input_tensor_type, - cast_disjoint_kwargs, - has_nodes: Union[int, bool] = True, - has_edges: Union[int, bool] = True, - has_graph_state: Union[int, bool] = False): +def template_cast_list_input(model_inputs, + input_tensor_type, + cast_disjoint_kwargs, + has_nodes: Union[int, bool] = True, + has_edges: Union[int, bool] = True, + has_angles: Union[int, bool] = False, + has_edge_indices: bool = True, + has_angle_indices: bool = False, + has_graph_state: Union[int, bool] = False, + return_sub_id: bool = True): + + standard_inputs = [x for x in model_inputs] + batched_nodes = [] batched_edges = [] + batched_angles = [] batched_state = [] + batched_indices = [] + batched_angle_indices = [] for i in range(int(has_nodes)): batched_nodes.append(standard_inputs.pop(0)) for i in range(int(has_edges)): batched_edges.append(standard_inputs.pop(0)) + for i in range(int(has_angles)): + batched_angles.append(standard_inputs.pop(0)) + for i in range(int(has_edge_indices)): + batched_indices.append(standard_inputs.pop(0)) + for i in range(int(has_angle_indices)): + batched_angle_indices.append(standard_inputs.pop(0)) for i in range(int(has_graph_state)): batched_state.append(standard_inputs.pop(0)) - batched_indices = standard_inputs.pop(0) - batched_id = standard_inputs disjoint_nodes = [] disjoint_edges = [] disjoint_state = [] + disjoint_angles = [] + disjoint_indices = [] + disjoint_angle_indices = [] + disjoint_id = [] if input_tensor_type in ["padded", "masked"]: - part_nodes, part_edges = batched_id - n, idx, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = CastBatchedIndicesToDisjoint( - **cast_disjoint_kwargs)([batched_nodes.pop(0), batched_indices, part_nodes, part_edges]) - disjoint_indices = [idx] - disjoint_nodes.append(n) - disjoint_id = [batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges] + if int(has_angle_indices) > 0: + part_nodes, part_edges, part_angle = batched_id + else: + part_nodes, part_edges = batched_id + part_angle = None + + for x in batched_indices: + _, idx, batch_id_node, batch_id_edge, node_id, edge_id, len_nodes, len_edges = CastBatchedIndicesToDisjoint( + **cast_disjoint_kwargs)([batched_nodes[0], x, part_nodes, part_edges]) + disjoint_indices.append(idx) + + for x in batched_angle_indices: + _, idx, _, batch_id_ang, _, ang_id, _, len_ang = CastBatchedIndicesToDisjoint( + **cast_disjoint_kwargs)([batched_indices[0], x, part_edges, part_angle]) + disjoint_angle_indices.append(idx) for x in batched_nodes: disjoint_nodes.append( CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([x, part_nodes])[0]) + for x in batched_edges: disjoint_edges.append( CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([x, part_edges])[0]) + + for x in batched_angles: + disjoint_angles.append( + CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([x, part_angle])[0]) + for x in batched_state: disjoint_state.append( CastBatchedGraphStateToDisjoint(**cast_disjoint_kwargs)(x)) elif input_tensor_type in ["ragged", "jagged"]: - n, idx, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = CastRaggedIndicesToDisjoint( - **cast_disjoint_kwargs)([batched_nodes.pop(0), batched_indices]) - disjoint_indices = [idx] - disjoint_nodes.append(n) - disjoint_id = [batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges] + for x in batched_indices: + _, idx, batch_id_node, batch_id_edge, node_id, edge_id, len_nodes, len_edges = CastRaggedIndicesToDisjoint( + **cast_disjoint_kwargs)([batched_nodes[0], x]) + disjoint_indices.append(idx) + + for x in batched_angle_indices: + _, idx, _, batch_id_ang, _, ang_id, _, len_ang = CastRaggedIndicesToDisjoint( + **cast_disjoint_kwargs)([batched_indices[0], x]) + disjoint_angle_indices.append(idx) for x in batched_nodes: disjoint_nodes.append( CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(x)[0]) + for x in batched_edges: disjoint_edges.append( CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(x)[0]) + + for x in batched_angles: + disjoint_angles.append( + CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(x)[0]) + disjoint_state = batched_state else: disjoint_nodes = batched_nodes disjoint_edges = batched_edges - disjoint_indices = [batched_indices] + disjoint_indices = batched_indices disjoint_state = batched_state + disjoint_angle_indices = batched_angle_indices + disjoint_angles = batched_angles + + if input_tensor_type in ["ragged", "jagged", "padded", "masked"]: + disjoint_id.append(batch_id_node) # noqa + disjoint_id.append(batch_id_edge) # noqa + if int(has_angle_indices) > 0: + disjoint_id.append(batch_id_ang) # noqa + if return_sub_id: + disjoint_id.append(node_id) # noqa + disjoint_id.append(edge_id) # noqa + if int(has_angle_indices) > 0: + disjoint_id.append(ang_id) # noqa + disjoint_id.append(len_nodes) # noqa + disjoint_id.append(len_edges) # noqa + if int(has_angle_indices) > 0: + disjoint_id.append(len_ang) # noqa + else: disjoint_id = batched_id - disjoint_model_inputs = disjoint_nodes + disjoint_edges + disjoint_state + disjoint_indices + disjoint_id + disjoint_model_inputs = disjoint_nodes + disjoint_edges + disjoint_angles + disjoint_indices + disjoint_angle_indices + disjoint_state + disjoint_id return disjoint_model_inputs + + +def template_cast_list_crystal_input(model_inputs, + input_tensor_type, + cast_disjoint_kwargs, + has_nodes: Union[int, bool] = True, + has_edges: Union[int, bool] = True, + has_graph_state: Union[int, bool] = False, + has_angle_indices: bool = False, + return_sub_id: bool = False): + pass