Skip to content

Commit

Permalink
Update docs. Added further training results.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Jan 4, 2024
1 parent 984f094 commit 762fd88
Show file tree
Hide file tree
Showing 7 changed files with 1,044 additions and 6 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,5 @@ __pycache__
/notebooks/esol/
/notebooks/DMPNN_ESOL_predict.png
/notebooks/DMPNN_esol_loss.png
/docs/source/GIN_ESOL_predict.png
/docs/source/GIN_esol_loss.png
324 changes: 322 additions & 2 deletions docs/source/layers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,335 @@
"id": "35de5004",
"metadata": {},
"source": [
"## Implementaion details"
"## Implementaion details\n",
"\n",
"The following steps that are most representative for GNNs have layers in `kgcnn.layers` ."
]
},
{
"cell_type": "markdown",
"id": "212298b7",
"metadata": {},
"source": [
"TODO"
"#### Casting\n",
"\n",
"Cast batched node and edge indices to a (single) disjoint graph representation of [Pytorch Geometric (PyG)](https://github.com/pyg-team/pytorch_geometric). For PyG a batch of graphs is represented by single graph which contains disjoint sub-graphs,\n",
"and the batch information is passed as batch ID tensor: `graph_id_node` and `graph_id_edge` .\n",
"For keras padded tensors can be used to input into keras models"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "89f9e509-9e27-424f-8b9e-9981a442026a",
"metadata": {},
"outputs": [],
"source": [
"from keras import ops \n",
"nodes = ops.convert_to_tensor([[[0.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]]])\n",
"edges = ops.convert_to_tensor([[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]], [[1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0], [-1.0, 1.0, 1.0]]])\n",
"edge_indices = ops.convert_to_tensor([[[0, 0], [0, 1], [1, 0], [1, 1]], [[0, 0], [0, 1], [1, 0], [1, 1]]], dtype=\"int64\")\n",
"node_mask = ops.convert_to_tensor([[True, False], [True, True]])\n",
"edge_mask = ops.convert_to_tensor([[True, False, False, False], [True, True, True, False]])"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0e488cb8-6aff-47fe-a73c-a3e492026378",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Disjoint index:\n",
" tensor([[0, 1, 1, 2],\n",
" [0, 1, 2, 1]], device='cuda:0')\n",
"Node attributes:\n",
" tensor([[0., 1.],\n",
" [1., 0.],\n",
" [1., 1.]], device='cuda:0')\n",
"Batch ID nodes:\n",
" tensor([0, 1, 1], device='cuda:0')\n"
]
}
],
"source": [
"from kgcnn.layers.casting import CastBatchedIndicesToDisjoint\n",
"disjoint_tensors = CastBatchedIndicesToDisjoint(uses_mask=True)([nodes, edge_indices, node_mask, edge_mask])\n",
"node_attr, disjoint_index, graph_id_node, graph_id_edge, node_id, edge_id, node_count, edge_count = disjoint_tensors\n",
"print(\"Disjoint index:\\n\", disjoint_index)\n",
"print(\"Node attributes:\\n\", node_attr)\n",
"print(\"Batch ID nodes:\\n\", graph_id_node)"
]
},
{
"cell_type": "markdown",
"id": "2e0328b6-44f2-4132-82ba-934a5ecfcab4",
"metadata": {},
"source": [
"Note that also ragged tensors can be used to input keras models which is much more effective and less costly in casting, but are only supported for tensorflow for now. If the tensor shape must not be changed for **JAX** also padded disjoint output can be generated with:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "acfb6ed4-7f61-4262-ba61-7b6f25d905f9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Disjoint index:\n",
" tensor([[0, 1, 0, 0, 0, 3, 3, 4, 0],\n",
" [0, 1, 0, 0, 0, 3, 4, 3, 0]], device='cuda:0')\n",
"Node attributes:\n",
" tensor([[0., 0.],\n",
" [0., 1.],\n",
" [0., 0.],\n",
" [1., 0.],\n",
" [1., 1.]], device='cuda:0')\n",
"Batch ID nodes:\n",
" tensor([0, 1, 0, 2, 2], device='cuda:0')\n"
]
}
],
"source": [
"disjoint_tensors = CastBatchedIndicesToDisjoint(uses_mask=True, padded_disjoint=True)([nodes, edge_indices, node_mask, edge_mask])\n",
"node_attr, disjoint_index, graph_id_node, graph_id_edge, node_id, edge_id, node_count, edge_count = disjoint_tensors\n",
"print(\"Disjoint index:\\n\", disjoint_index)\n",
"print(\"Node attributes:\\n\", node_attr)\n",
"print(\"Batch ID nodes:\\n\", graph_id_node)"
]
},
{
"cell_type": "markdown",
"id": "27887067-6e43-4692-b84a-4f67403004d4",
"metadata": {},
"source": [
"Here nodes and edges with ID 0 are dummy nodes and can be later removed. They do message passing without interfering with the oder subgraphs.\n",
"However, using a padded batch is much more effective but requires a dataloader, i.e. `kgcnn.io` ."
]
},
{
"cell_type": "markdown",
"id": "76c77bc8-4b75-4111-866a-0fc92fd3a3fe",
"metadata": {},
"source": [
"#### Gather"
]
},
{
"cell_type": "markdown",
"id": "aa3f9864-8196-4f3e-be27-efe8e0506bec",
"metadata": {},
"source": [
"Selecting nodes via edge indices is simply realised by using `take` and carried out by the keras layer with some options:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fb74a325-2d64-4bc9-ad6c-c3709b79dbd0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([9, 4])\n",
"torch.Size([9, 2]) torch.Size([9, 2])\n"
]
}
],
"source": [
"from kgcnn.layers.gather import GatherNodes\n",
"nodes_per_edge = GatherNodes(split_indices=(0, 1), concat_axis=1)([node_attr, disjoint_index])\n",
"nodes_in, nodes_out = GatherNodes(split_indices=(0, 1), concat_axis=None)([node_attr, disjoint_index])\n",
"print(nodes_per_edge.shape)\n",
"print(nodes_in.shape, nodes_out.shape)"
]
},
{
"cell_type": "markdown",
"id": "b0a0d3d5-52a1-4bd0-a5bc-4d012a0ad3dd",
"metadata": {},
"source": [
"#### Convolution"
]
},
{
"cell_type": "markdown",
"id": "edf16154-67f8-4ee5-a001-2e6e1da9832d",
"metadata": {},
"source": [
"Convolution per node can now be done with for example a standard keras `Dense` layer."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4199c52c-75da-4f76-bfbb-05883b9f4e65",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([9, 16])\n"
]
}
],
"source": [
"from keras.layers import Dense\n",
"edges_transformed = Dense(units=16, use_bias=True, activation=\"swish\")(nodes_per_edge)\n",
"print(edges_transformed.shape)"
]
},
{
"cell_type": "markdown",
"id": "2a84eb54-9533-410f-bc11-b43f12f13296",
"metadata": {},
"source": [
"#### Aggregation"
]
},
{
"cell_type": "markdown",
"id": "624d6186-7336-4d89-922e-ff067b2a8c8b",
"metadata": {},
"source": [
"Aggregation of edges per node can be done with scatter or segment operations. For backward compatibility and without any additional transformation\n",
"`AggregateLocalEdges` offers a direct approach. Additionally the node tensor has to be provided for the target shape (batch dimension) but can also be directly used to aggregate edges into."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b8cf3d59-de77-4b2f-85b6-f2b70ec6d92c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([5, 16])\n"
]
}
],
"source": [
"from kgcnn.layers.aggr import AggregateLocalEdges\n",
"edges_aggregated = AggregateLocalEdges(pooling_method=\"scatter_sum\", pooling_index=0)([node_attr, edges_transformed, disjoint_index])\n",
"print(edges_aggregated.shape)"
]
},
{
"cell_type": "markdown",
"id": "ba98882a-2099-4ffc-8122-1ad052c3311a",
"metadata": {},
"source": [
"The basic aggregation layer design is at the moment:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1dd34fa8-abd9-4bdb-9be3-4dd67e0c5a3a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([5, 16])\n"
]
}
],
"source": [
"from kgcnn.layers.aggr import Aggregate\n",
"edges_aggregated = Aggregate(pooling_method=\"scatter_sum\")([edges_transformed, disjoint_index[0], node_attr])\n",
"print(edges_aggregated.shape)"
]
},
{
"cell_type": "markdown",
"id": "93ba457f-8be8-4ccd-9ec0-9bdfef65b743",
"metadata": {},
"source": [
"#### Pooling"
]
},
{
"cell_type": "markdown",
"id": "85b4e60a-6f5a-4d6a-bc57-daef1b9db69f",
"metadata": {},
"source": [
"For graph level embedding nodes or edges are pooled per graph. Therefore the graph batch ID tensor is required and which can be\n",
"done with `Aggregate` in the same way. This is used in `kgcnn.layers.pooling` . For reference we can use the `node_count` tensor."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6091e92e-1762-4ef9-ad4f-45350376aa4e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
" 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [-0.2240, -0.1971, 0.4667, -0.2024, -0.1236, 0.4036, -0.0160, 0.2721,\n",
" -0.0063, -0.1154, 0.6441, 0.4041, -0.2673, 0.4717, -0.2080, 0.1283],\n",
" [-0.7468, -0.1389, 0.2592, -0.3825, 0.2881, 0.7621, 0.9968, 0.7264,\n",
" 0.4894, 0.4421, -0.4755, 0.6927, -0.3123, -0.3772, 0.4574, 1.0335]],\n",
" device='cuda:0', grad_fn=<ScatterReduceBackward0>)\n"
]
}
],
"source": [
"from kgcnn.layers.pooling import PoolingNodes\n",
"\n",
"graph_output = PoolingNodes()([node_count, edges_aggregated, graph_id_node])\n",
"print(graph_output)"
]
},
{
"cell_type": "markdown",
"id": "b98f31d6-2bbc-4e94-a917-21d8ec40e2c5",
"metadata": {},
"source": [
"since we used a padded disjoint representation the 0 graph was a dummy graph to deal with empty nodes.\n",
"It must be removed to get the final graph embeddings for the two samples in the batch:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "06ac06d3-1f1a-4e5d-8daa-10cfc217a371",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.2240, -0.1971, 0.4667, -0.2024, -0.1236, 0.4036, -0.0160, 0.2721,\n",
" -0.0063, -0.1154, 0.6441, 0.4041, -0.2673, 0.4717, -0.2080, 0.1283],\n",
" [-0.7468, -0.1389, 0.2592, -0.3825, 0.2881, 0.7621, 0.9968, 0.7264,\n",
" 0.4894, 0.4421, -0.4755, 0.6927, -0.3123, -0.3772, 0.4574, 1.0335]],\n",
" device='cuda:0', grad_fn=<SliceBackward0>)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out = graph_output[1:]\n",
"out"
]
},
{
Expand Down
Loading

0 comments on commit 762fd88

Please sign in to comment.