Skip to content

Commit

Permalink
update for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Nov 16, 2023
1 parent 2753e61 commit 86b3670
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions kgcnn/models/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def template_cast_list_input(model_inputs,
- angle_indices (Tensor): Index list for angles of shape `(batch, K, 2)` referring to edges.
- graph_state (Tensor): Graph attributes of shape `(batch, F)` .
- image_translation (Tensor): Indices of the periodic image the sending node is located in.
Shape is `(batch, M, 3)` .
Shape is `(batch, M, 3)` .
- lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` .
- node_mask (Tensor): Mask for padded nodes of shape `(batch, N)` .
- edge_mask (Tensor): Mask for padded edges of shape `(batch, M)` .
Expand All @@ -93,7 +93,7 @@ def template_cast_list_input(model_inputs,
- angle_count (Tensor): Total number of angle if padding is used of shape `(batch, )` .
Ragged or Jagged Inputs:
list: obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice]`
list: :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice]`
- nodes (RaggedTensor): Node attributes of shape `(batch, None, F)` or `(batch, None)`
using an embedding layer.
Expand All @@ -105,11 +105,11 @@ def template_cast_list_input(model_inputs,
- angle_indices (RaggedTensor): Index list for angles of shape `(batch, None, 2)` referring to edges.
- graph_state (Tensor): Graph attributes of shape `(batch, F)` .
- image_translation (RaggedTensor): Indices of the periodic image the sending node is located in.
Shape is `(batch, None, 3)` .
Shape is `(batch, None, 3)` .
- lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` .
Disjoint Input:
list: obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice,
list: :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice,
graph_id_node, graph_id_edge, graph_id_angle, nodes_id, edges_id, angle_id, nodes_count, edges_count,
angles_count]`
Expand All @@ -120,7 +120,7 @@ def template_cast_list_input(model_inputs,
- angle_indices (Tensor): Index list for angles of shape `(2, [K])` referring to edges.
- graph_state (Tensor): Graph attributes of shape `(batch, F)` .
- image_translation (Tensor): Indices of the periodic image the sending node is located in.
Shape is `([M], 3)` .
Shape is `([M], 3)` .
- lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` .
- graph_id_node (Tensor): ID tensor of graph assignment in disjoint graph of shape `([N], )` .
- graph_id_edge (Tensor): ID tensor of graph assignment in disjoint graph of shape `([M], )` .
Expand Down

0 comments on commit 86b3670

Please sign in to comment.