Skip to content

Commit

Permalink
tested and updated dataloader. Fixed minor bugs. Added more test trai…
Browse files Browse the repository at this point in the history
…ning.
  • Loading branch information
PatReis committed Dec 18, 2023
1 parent 3826a67 commit 6544fcd
Show file tree
Hide file tree
Showing 16 changed files with 403 additions and 141 deletions.
4 changes: 2 additions & 2 deletions kgcnn/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
import os
from sklearn.model_selection import KFold
from kgcnn.io.loader import experimental_tf_disjoint_list_generator
from kgcnn.io.loader import tf_disjoint_list_generator
# import typing as t
from typing import Union, List, Callable, Dict, Optional
# from collections.abc import MutableSequence
Expand Down Expand Up @@ -332,7 +332,7 @@ def rename_property_on_graphs(self, old_property_name: str, new_property_name: s
def tf_disjoint_data_generator(self, inputs, outputs, **kwargs):
assert isinstance(inputs, list), "Dictionary input is not yet implemented"
module_logger.info("Dataloader is experimental and not fully tested or stable.")
return experimental_tf_disjoint_list_generator(self, inputs=inputs, outputs=outputs, **kwargs)
return tf_disjoint_list_generator(self, inputs=inputs, outputs=outputs, **kwargs)


class MemoryGraphDataset(MemoryGraphList):
Expand Down
157 changes: 46 additions & 111 deletions kgcnn/io/loader.py
Original file line number Diff line number Diff line change
@@ -1,131 +1,66 @@
import keras as ks
from typing import Union
import numpy as np
from numpy.random import Generator, PCG64
import tensorflow as tf


def experimental_tf_disjoint_list_generator(graphs,
inputs,
outputs,
has_nodes=True,
has_edges=True,
has_graph_state=False,
batch_size=32,
shuffle=True):
def generator():
dataset_size = len(graphs)
data_index = np.arange(dataset_size)

if shuffle:
np.random.shuffle(data_index)

for batch_index in range(0, dataset_size, batch_size):
idx = data_index[batch_index:batch_index + batch_size]
graphs_batch = [graphs[i] for i in idx]

batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = [None for _ in range(6)]
out = []
inputs_pos = 0
for j in range(int(has_nodes)):
array_list = [x[inputs[inputs_pos]["name"]] for x in graphs_batch]
out.append(np.concatenate(array_list, axis=0))
inputs_pos += 1
if j == 0:
count_nodes = np.array([len(x) for x in array_list], dtype="int64")
batch_id_node = np.repeat(np.arange(len(array_list), dtype="int64"), repeats=count_nodes)
node_id = np.concatenate([np.arange(x, dtype="int64") for x in count_nodes], axis=0)

for j in range(int(has_edges)):
array_list = [x[inputs[inputs_pos]["name"]] for x in graphs_batch]
out.append(np.concatenate(array_list, axis=0, dtype=inputs[inputs_pos]["dtype"]))
inputs_pos += 1

for j in range(int(has_graph_state)):
array_list = [x[inputs[inputs_pos]["name"]] for x in graphs_batch]
out.append(np.array(array_list, dtype=inputs[inputs_pos]["dtype"]))
inputs_pos += 1

# Indices
array_list = [x[inputs[inputs_pos]["name"]] for x in graphs_batch]
count_edges = np.array([len(x) for x in array_list], dtype="int64")
batch_id_edge = np.repeat(np.arange(len(array_list), dtype="int64"), repeats=count_edges)
edge_id = np.concatenate([np.arange(x, dtype="int64") for x in count_edges], axis=0)
edge_indices_flatten = np.concatenate(array_list, axis=0)

node_splits = np.pad(np.cumsum(count_nodes), [[1, 0]])
offset_edge_indices = np.expand_dims(np.repeat(node_splits[:-1], count_edges), axis=-1)
disjoint_indices = edge_indices_flatten + offset_edge_indices
disjoint_indices = np.transpose(disjoint_indices)
out.append(disjoint_indices)

out = out + [batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges]

if isinstance(outputs, list):
out_y = []
for k in range(len(outputs)):
array_list = [x[outputs[k]["name"]] for x in graphs_batch]
out_y.append(np.array(array_list, dtype=outputs[k]["dtype"]))
elif isinstance(outputs, dict):
out_y = np.array(
[x[outputs["name"]] for x in graphs_batch], dtype=outputs["dtype"])
else:
raise ValueError()

yield tuple(out), out_y

input_spec = tuple([tf.TensorSpec(shape=tuple([None] + list(x["shape"])), dtype=x["dtype"]) for x in inputs])

if isinstance(outputs, list):
output_spec = tuple([tf.TensorSpec(shape=tuple([None] + list(x["shape"])), dtype=x["dtype"]) for x in outputs])
elif isinstance(outputs, dict):
output_spec = tf.TensorSpec(shape=tuple([None] + list(outputs["shape"])), dtype=outputs["dtype"])
else:
raise ValueError()

data_loader = tf.data.Dataset.from_generator(
generator,
output_signature=(
input_spec,
output_spec
)
)

return data_loader


def tf_disjoint_list_generator(
graphs,
inputs: list,
outputs: list,
assignment_to_id: list = None,
assignment_of_indices: list = None,
flag_batch_id: list = None,
flag_count: list = None,
flag_subgraph_id: list = None,
pos_batch_id: list = None,
pos_subgraph_id: list = None,
pos_count: list = None,
batch_size=32,
shuffle=True
padded_disjoint=False,
epochs=None,
shuffle=True,
seed=42
):
dataset_size = len(graphs)
data_index = np.arange(dataset_size)
num_inputs = len(inputs)

if len(assignment_to_id) < num_inputs:
assignment_to_id = assignment_to_id + [None for _ in range(num_inputs-len(assignment_to_id))]
if len(assignment_of_indices) < num_inputs:
assignment_of_indices = assignment_of_indices + [None for _ in range(num_inputs-len(assignment_of_indices))]

flag_batch_id = [None for _ in range(num_inputs)]
for i, x in enumerate(pos_batch_id):
flag_batch_id[x] = i

flag_count = [None for _ in range(num_inputs)]
for i, x in enumerate(pos_count):
flag_count[x] = i

flag_subgraph_id = [None for _ in range(num_inputs)]
for i, x in enumerate(pos_subgraph_id):
flag_subgraph_id[x] = i

all_flags = [flag_batch_id, flag_count, flag_subgraph_id]
is_attributes = [True if all([x[i] is None for x in all_flags]) else False for i in range(num_inputs)]

if padded_disjoint:
if epochs is None:
raise ValueError("Requires number of epochs if `padded_disjoint=True` .")

rng = Generator(PCG64(seed=seed))

def generator():
dataset_size = len(graphs)
data_index = np.arange(dataset_size)
num_inputs = len(inputs)
all_flags = [flag_batch_id, flag_count, flag_subgraph_id]
is_attributes = [True if all([x[i] is not None for x in all_flags]) else False for i in range(num_inputs)]
where_batch = []
where_subgraph= []
where_count = []
num_attributes = sum(is_attributes)

if shuffle:
np.random.shuffle(data_index)
rng.shuffle(data_index)

for batch_index in range(0, dataset_size, batch_size):
idx = data_index[batch_index:batch_index + batch_size]
graphs_batch = [graphs[i] for i in idx]

out = [None for _ in range(num_attributes)]
out_counts = [None for _ in range(num_attributes)]
out = [None for _ in range(num_inputs)]
out_counts = [None for _ in range(num_inputs)]

for i in range(num_inputs):
if not is_attributes[i]:
Expand All @@ -139,12 +74,12 @@ def generator():
counts = np.array([len(x) for x in array_list], dtype="int64")
out_counts[i] = counts
ids = assignment_to_id[i]
if out[where_count[ids]] is not None:
out[where_count[ids]] = counts
if out[where_batch[ids]] is not None:
out[where_batch[ids]] = np.repeat(np.arange(len(array_list), dtype="int64"), repeats=counts)
if out[where_subgraph[ids]] is not None:
out[where_subgraph[ids]] = np.concatenate([np.arange(x, dtype="int64") for x in counts], axis=0)
if out[pos_count[ids]] is None:
out[pos_count[ids]] = counts
if out[pos_batch_id[ids]] is None:
out[pos_batch_id[ids]] = np.repeat(np.arange(len(array_list), dtype="int64"), repeats=counts)
if out[pos_subgraph_id[ids]] is None:
out[pos_subgraph_id[ids]] = np.concatenate([np.arange(x, dtype="int64") for x in counts], axis=0)

# Indices
for i in range(num_inputs):
Expand Down
2 changes: 1 addition & 1 deletion kgcnn/layers/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def call(self, inputs, **kwargs):
"""Forward pass.
Args:
inputs (list): `[position, edge_image, lattice, num_edges]`
inputs (list): `[position, edge_image, lattice, batch_id_edge]`
- position (Tensor): Positions of shape `(M, 3)`
- edge_image (Tensor): Position in which image to shift of shape `(M, 3)`
Expand Down
1 change: 1 addition & 0 deletions kgcnn/literature/CMPNN/_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self, units, static_output_shape=None,
**kwargs):
super(PoolingNodesGRU, self).__init__(**kwargs)
self.units = units
self.static_output_shape = static_output_shape
self.cast_layer = CastDisjointToBatchedAttributes(
static_output_shape=static_output_shape, return_mask=True)
self.gru = GRU(
Expand Down
8 changes: 6 additions & 2 deletions kgcnn/literature/GIN/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ def set_scale(*args, **kwargs):

make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__)


model_default_edge = {
"name": "GINE",
"inputs": [
Expand Down Expand Up @@ -244,7 +243,12 @@ def make_model_edge(inputs: list = None,
model_inputs = [Input(**x) for x in inputs]

disjoint_inputs = template_cast_list_input(
model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs)
model_inputs,
input_tensor_type=input_tensor_type,
cast_disjoint_kwargs=cast_disjoint_kwargs,
mask_assignment=[0, 1, 1],
index_assignment=[None, None, 0]
)

n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs

Expand Down
2 changes: 1 addition & 1 deletion kgcnn/literature/Megnet/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def model_disjoint_crystal(
# Edge distance as Gauss-Basis
if make_distance:
pos1, pos2 = NodePosition()([x, edi])
pos2 = ShiftPeriodicLattice()([pos2, edge_image, lattice])
pos2 = ShiftPeriodicLattice()([pos2, edge_image, lattice, batch_id_edge])
ep = NodeDistanceEuclidean()([pos1, pos2])
else:
ep = x
Expand Down
4 changes: 2 additions & 2 deletions kgcnn/literature/NMPN/_make.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import keras as ks
from kgcnn.layers.scale import get as get_scaler
from ._model import model_disjoint
from ._model import model_disjoint, model_disjoint_crystal
from kgcnn.layers.modules import Input
from kgcnn.models.casting import template_cast_output, template_cast_list_input
from kgcnn.models.utils import update_model_kwargs
Expand Down Expand Up @@ -309,7 +309,7 @@ def make_crystal_model(inputs: list = None,

n, x, d_indices, img, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj

out = model_disjoint(
out = model_disjoint_crystal(
[n, x, d_indices, img, lattice, batch_id_node, batch_id_edge, count_nodes, count_edges],
use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False,
use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False,
Expand Down
2 changes: 1 addition & 1 deletion kgcnn/literature/NMPN/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def model_disjoint_crystal(inputs,
if make_distance:
x = ed
pos1, pos2 = NodePosition()([x, disjoint_indices])
pos2 = ShiftPeriodicLattice()([pos2, edge_image, lattice])
pos2 = ShiftPeriodicLattice()([pos2, edge_image, lattice, batch_id_edge])
ed = NodeDistanceEuclidean()([pos1, pos2])

if expand_distance:
Expand Down
3 changes: 3 additions & 0 deletions kgcnn/models/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def template_cast_output(model_outputs,

# Output embedding choice
if output_embedding == 'graph':
# Here we could also modify the behaviour for direct disjoint output to not remove padded ones,
# in case also the output is padded.
out = CastDisjointToBatchedGraphState(**cast_disjoint_kwargs)(out)
elif output_embedding == 'node':
if output_tensor_type in ["padded", "masked"]:
Expand Down Expand Up @@ -85,6 +87,7 @@ def template_cast_list_input(model_inputs,
:obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice,...]` .
Note that in place of nodes or edges also more than one tensor can be provided, depending on the model, for example
:obj:`[nodes_1, nodes_2, edges_1, edges_2, edge_indices, ...]` .
However, for future models we intend to used named inputs rather than a list that is sensible to ordering.
Whether to use mask or length tensor for padded as well as further parameter of casting has to be set with
(dict) :obj:`cast_disjoint_kwargs` .
Expand Down
31 changes: 20 additions & 11 deletions notebooks/tutorial_model_loading_options.ipynb

Large diffs are not rendered by default.

6 changes: 0 additions & 6 deletions training/hyper/hyper_qm7.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@
}
},
"training": {
"cross_validation": None,
"fit": {
"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2,
"callbacks": [
Expand Down Expand Up @@ -169,7 +168,6 @@
}
},
"training": {
"cross_validation": None,
"fit": {
"batch_size": 32, "epochs": 500, "validation_freq": 10, "verbose": 2,
"callbacks": [
Expand Down Expand Up @@ -232,7 +230,6 @@
}
},
"training": {
"cross_validation": None,
"fit": {
"batch_size": 32, "epochs": 872, "validation_freq": 10, "verbose": 2, "callbacks": []
},
Expand Down Expand Up @@ -304,7 +301,6 @@
}
},
"training": {
"cross_validation": None,
"fit": {
"batch_size": 10, "epochs": 872, "validation_freq": 10, "verbose": 2, "callbacks": []
},
Expand Down Expand Up @@ -388,7 +384,6 @@
}
},
"training": {
"cross_validation": None,
"fit": {
"batch_size": 128, "epochs": 900, "validation_freq": 10, "verbose": 2,
"callbacks": [
Expand Down Expand Up @@ -473,7 +468,6 @@
}
},
"training": {
"cross_validation": None,
"fit": {
"batch_size": 64, "epochs": 800, "validation_freq": 10, "verbose": 2,
"callbacks": [
Expand Down
Loading

0 comments on commit 6544fcd

Please sign in to comment.