Skip to content

Commit

Permalink
added new loader and small minor bugs in hyperparameter. Added traini…
Browse files Browse the repository at this point in the history
…ng results.
  • Loading branch information
PatReis committed Dec 15, 2023
1 parent 1c2080c commit 3826a67
Show file tree
Hide file tree
Showing 10 changed files with 398 additions and 7 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ original implementations (with proper licencing).
* **[MoGAT](kgcnn/literature/MoGAT)**: [Multi-order graph attention network for water solubility prediction and interpretation](https://www.nature.com/articles/s41598-022-25701-5) by Lee et al. (2023)
* **[MXMNet](kgcnn/literature/MXMNet)**: [Molecular Mechanics-Driven Graph Neural Network with Multiplex Graph for Molecular Structures](https://arxiv.org/abs/2011.07457) by Zhang et al. (2020)
* **[NMPN](kgcnn/literature/NMPN)**: [Neural Message Passing for Quantum Chemistry](http://arxiv.org/abs/1704.01212) by Gilmer et al. (2017)
* **[Unet](kgcnn/literature/Unet)**: [Graph U-Nets](http://proceedings.mlr.press/v97/gao19a/gao19a.pdf) by H. Gao and S. Ji (2019)
* **[PAiNN](kgcnn/literature/PAiNN)**: [Equivariant message passing for the prediction of tensorial properties and molecular spectra](https://arxiv.org/pdf/2102.03150.pdf) by Schütt et al. (2020)
* **[RGCN](kgcnn/literature/RGCN)**: [Modeling Relational Data with Graph Convolutional Networks](https://arxiv.org/abs/1703.06103) by Schlichtkrull et al. (2017)
* **[rGIN](kgcnn/literature/rGIN)**: [Random Features Strengthen Graph Neural Networks](https://arxiv.org/abs/2002.03155) by Sato et al. (2020)
Expand Down
3 changes: 2 additions & 1 deletion kgcnn/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ def rename_property_on_graphs(self, old_property_name: str, new_property_name: s
get = obtain_property

# Thin wrapper: forwards this dataset object together with `inputs`, `outputs` and any
# extra keyword arguments to `experimental_tf_disjoint_list_generator` and returns its result.
# NOTE(review): this span comes from a diff view, so the first log call below appears to be
# the removed wording and the second one its replacement - confirm against the real file.
def tf_disjoint_data_generator(self, inputs, outputs, **kwargs):
module_logger.info("Dataloader is experimental and not fully tested nor stable.")
# Only list-type `inputs` are accepted here; anything else fails with an AssertionError.
# NOTE(review): `assert` is stripped under `python -O`, so this check is not guaranteed to run.
assert isinstance(inputs, list), "Dictionary input is not yet implemented"
# Informational log only; it does not change what is returned.
module_logger.info("Dataloader is experimental and not fully tested or stable.")
return experimental_tf_disjoint_list_generator(self, inputs=inputs, outputs=outputs, **kwargs)


Expand Down
95 changes: 95 additions & 0 deletions kgcnn/io/loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import keras as ks
from typing import Union
import numpy as np
import tensorflow as tf

Expand Down Expand Up @@ -90,3 +91,97 @@ def generator():
)

return data_loader


def tf_disjoint_list_generator(
        graphs,
        inputs: list,
        outputs: list,
        assignment_to_id: list = None,
        assignment_of_indices: list = None,
        flag_batch_id: list = None,
        flag_count: list = None,
        flag_subgraph_id: list = None,
        batch_size=32,
        shuffle=True
):
    r"""Build a :obj:`tf.data.Dataset` that yields batches of graphs as flat (disjoint) arrays.

    For every batch, each input that is a graph property is read from the graphs by its
    ``"name"``. Properties without an assignment id are stacked with :obj:`np.array`.
    Properties with an assignment id are concatenated along axis 0, and the per-graph
    lengths are used to generate the requested count, batch-id and subgraph-id tensors.
    Index properties are shifted by the cumulative item count of the property they refer to
    and are returned transposed.

    Args:
        graphs: Sized, indexable collection of graphs. Each graph must be subscriptable by the
            property names given in `inputs` and `outputs`.
        inputs (list): One dictionary per model input with keys ``"name"``, ``"dtype"`` and ``"shape"``.
        outputs (list, dict): A list of such dictionaries or a single dictionary for the targets.
        assignment_to_id (list): Per input, the id of the group the property is concatenated
            for, or `None` if the property is stacked per graph. Lists that are missing or shorter
            than `inputs` are padded with `None`.
        assignment_of_indices (list): Per input, the position in `inputs` of the property that
            the indices of this input refer to, or `None` if the input does not hold indices.
        flag_batch_id (list): Per input, the group id for which this input position receives the
            generated batch assignment, or `None`.
        flag_count (list): Per input, the group id for which this input position receives the
            generated item counts per graph, or `None`.
        flag_subgraph_id (list): Per input, the group id for which this input position receives
            the generated running index within each graph, or `None`.
        batch_size (int): Number of graphs per batch. Default is 32.
        shuffle (bool): Whether to shuffle the graph order each time the generator is started.
            Default is True.

    Raises:
        ValueError: If an index input or the input it refers to has no assignment id, so that
            no item counts are available to compute the index offsets.

    Returns:
        tf.data.Dataset: Dataset yielding ``(tuple_of_inputs, targets)`` per batch.
    """
    num_inputs = len(inputs)

    def _per_input(values):
        # All per-input lists default to None and may be shorter than `inputs`;
        # pad them so that they can be indexed by input position without errors.
        values = [] if values is None else list(values)
        return values + [None] * (num_inputs - len(values))

    def _position_by_id(flags):
        # Invert a flag list: group id -> position of the input that receives the tensor.
        return {flag: pos for pos, flag in enumerate(flags) if flag is not None}

    assignment_to_id = _per_input(assignment_to_id)
    assignment_of_indices = _per_input(assignment_of_indices)
    flag_batch_id = _per_input(flag_batch_id)
    flag_count = _per_input(flag_count)
    flag_subgraph_id = _per_input(flag_subgraph_id)

    # An input is a real graph property only if it is not one of the generated tensors.
    all_flags = (flag_batch_id, flag_count, flag_subgraph_id)
    is_attributes = [all(flags[i] is None for flags in all_flags) for i in range(num_inputs)]
    where_batch = _position_by_id(flag_batch_id)
    where_count = _position_by_id(flag_count)
    where_subgraph = _position_by_id(flag_subgraph_id)

    def generator():
        dataset_size = len(graphs)
        data_index = np.arange(dataset_size)

        if shuffle:
            np.random.shuffle(data_index)

        for batch_index in range(0, dataset_size, batch_size):
            idx = data_index[batch_index:batch_index + batch_size]
            graphs_batch = [graphs[i] for i in idx]

            # One slot per input (not per attribute), since slots are addressed by input position.
            out = [None] * num_inputs
            out_counts = [None] * num_inputs

            for i in range(num_inputs):
                if not is_attributes[i]:
                    continue

                array_list = [x[inputs[i]["name"]] for x in graphs_batch]
                ids = assignment_to_id[i]
                if ids is None:
                    out[i] = np.array(array_list, dtype=inputs[i]["dtype"])
                    continue

                out[i] = np.concatenate(array_list, axis=0)
                counts = np.array([len(x) for x in array_list], dtype="int64")
                out_counts[i] = counts

                # Fill the generated tensors of this group once; the first property wins.
                pos = where_count.get(ids)
                if pos is not None and out[pos] is None:
                    out[pos] = counts
                pos = where_batch.get(ids)
                if pos is not None and out[pos] is None:
                    out[pos] = np.repeat(np.arange(len(array_list), dtype="int64"), repeats=counts)
                pos = where_subgraph.get(ids)
                if pos is not None and out[pos] is None:
                    out[pos] = np.concatenate([np.arange(x, dtype="int64") for x in counts], axis=0)

            # Indices
            for i in range(num_inputs):
                target = assignment_of_indices[i]
                if target is None:
                    continue
                edge_indices_flatten = out[i]
                count_nodes = out_counts[target]
                count_edges = out_counts[i]
                if count_nodes is None or count_edges is None:
                    raise ValueError(
                        "Indices at input %s and their reference input %s both require an "
                        "entry in `assignment_to_id` ." % (i, target))
                # Start offset of each graph within the concatenated reference property.
                node_splits = np.pad(np.cumsum(count_nodes), [[1, 0]])
                offset_edge_indices = np.expand_dims(np.repeat(node_splits[:-1], count_edges), axis=-1)
                disjoint_indices = edge_indices_flatten + offset_edge_indices
                out[i] = np.transpose(disjoint_indices)

            if isinstance(outputs, list):
                out_y = [
                    np.array([x[spec["name"]] for x in graphs_batch], dtype=spec["dtype"])
                    for spec in outputs
                ]
            else:
                out_y = np.array(
                    [x[outputs["name"]] for x in graphs_batch], dtype=outputs["dtype"])

            yield tuple(out), out_y

    input_spec = tuple([tf.TensorSpec(shape=tuple([None] + list(x["shape"])), dtype=x["dtype"]) for x in inputs])

    if isinstance(outputs, list):
        output_spec = tuple([tf.TensorSpec(shape=tuple([None] + list(x["shape"])), dtype=x["dtype"]) for x in outputs])
    else:
        output_spec = tf.TensorSpec(shape=tuple([None] + list(outputs["shape"])), dtype=outputs["dtype"])

    data_loader = tf.data.Dataset.from_generator(
        generator,
        output_signature=(
            input_spec,
            output_spec
        )
    )

    return data_loader
6 changes: 3 additions & 3 deletions kgcnn/literature/NMPN/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from ._make import make_model, model_default
# from ._make import make_crystal_model, model_crystal_default
from ._make import make_crystal_model, model_crystal_default


__all__ = [
"make_model",
"model_default",
# "make_crystal_model",
# "model_crystal_default"
"make_crystal_model",
"model_crystal_default"
]
2 changes: 1 addition & 1 deletion training/hyper/hyper_mp_jdft2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@
"input_tensor_type": "ragged",
'input_embedding': None,
"input_node_embedding": {"input_dim": 95, "output_dim": 64},
"input_edge_embedding": {"input_dim": 100, "output_dim": 64},
# "input_edge_embedding": {"input_dim": 100, "output_dim": 64},
"make_distance": True, "expand_distance": True,
'gauss_args': {"bins": 25, "distance": 5, "offset": 0.0, "sigma": 0.4},
'meg_block_args': {'node_embed': [64, 32, 32], 'edge_embed': [64, 32, 32],
Expand Down
2 changes: 1 addition & 1 deletion training/hyper/hyper_qm9_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
"input_tensor_type": "ragged",
"input_embedding": None,
"input_node_embedding": {"input_dim": 10, "output_dim": 16},
"input_edge_embedding": {"input_dim": 100, "output_dim": 64},
# "input_edge_embedding": {"input_dim": 100, "output_dim": 64},
"gauss_args": {"bins": 20, "distance": 4, "offset": 0.0, "sigma": 0.4},
"meg_block_args": {"node_embed": [64, 32, 32], "edge_embed": [64, 32, 32],
"env_embed": [64, 32, 32], "activation": "kgcnn>softplus2"},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
OS: posix_linux
backend: tensorflow
cuda_available: 'True'
data_unit: meV/atom
date_time: '2023-12-15 14:11:18'
device_id: '[LogicalDevice(name=''/device:CPU:0'', device_type=''CPU''), LogicalDevice(name=''/device:GPU:0'',
device_type=''GPU'')]'
device_memory: '[]'
device_name: '[{}, {''compute_capability'': (8, 0), ''device_name'': ''NVIDIA A100
80GB PCIe''}]'
epochs:
- 1000
- 1000
- 1000
- 1000
- 1000
execute_folds: null
kgcnn_version: 4.0.0
learning_rate:
- 1.1979999726463575e-05
- 1.1979999726463575e-05
- 1.1979999726463575e-05
- 1.1979999726463575e-05
- 1.1979999726463575e-05
loss:
- 0.0164102204144001
- 0.049058035016059875
- 0.019465263932943344
- 0.04565507546067238
- 0.018842527642846107
max_learning_rate:
- 0.0010000000474974513
- 0.0010000000474974513
- 0.0010000000474974513
- 0.0010000000474974513
- 0.0010000000474974513
max_loss:
- 0.4873706102371216
- 0.4823596477508545
- 0.4699196219444275
- 0.44965362548828125
- 0.46280232071876526
max_scaled_mean_absolute_error:
- 70.67137908935547
- 67.2155532836914
- 64.04151916503906
- 57.88063049316406
- 55.339359283447266
max_scaled_root_mean_squared_error:
- 152.97747802734375
- 145.45938110351562
- 142.39430236816406
- 133.3281707763672
- 124.1724624633789
max_val_loss:
- 0.26619401574134827
- 0.3685653507709503
- 0.40042510628700256
- 0.6030166745185852
- 0.6730947494506836
max_val_scaled_mean_absolute_error:
- 38.59397888183594
- 51.37816619873047
- 54.60137939453125
- 77.64346313476562
- 80.49958038330078
max_val_scaled_root_mean_squared_error:
- 83.75748443603516
- 116.78331756591797
- 139.4757537841797
- 208.695556640625
- 203.95716857910156
min_learning_rate:
- 1.1979999726463575e-05
- 1.1979999726463575e-05
- 1.1979999726463575e-05
- 1.1979999726463575e-05
- 1.1979999726463575e-05
min_loss:
- 0.0164102204144001
- 0.049058035016059875
- 0.019465263932943344
- 0.04565083235502243
- 0.018809372559189796
min_scaled_mean_absolute_error:
- 2.3664445877075195
- 6.843836784362793
- 2.6597533226013184
- 5.826504707336426
- 2.251401424407959
min_scaled_root_mean_squared_error:
- 12.384324073791504
- 72.17166900634766
- 12.613103866577148
- 70.442138671875
- 11.607565879821777
min_val_loss:
- 0.1912095546722412
- 0.27041757106781006
- 0.34239882230758667
- 0.4209592342376709
- 0.5613946318626404
min_val_scaled_mean_absolute_error:
- 27.72239875793457
- 37.69633102416992
- 46.68900680541992
- 54.20203399658203
- 67.14065551757812
min_val_scaled_root_mean_squared_error:
- 50.73049545288086
- 101.22430419921875
- 123.49358367919922
- 149.4101104736328
- 180.5041961669922
model_class: make_crystal_model
model_name: CGCNN
model_version: '2023-11-28'
multi_target_indices: null
number_histories: 5
scaled_mean_absolute_error:
- 2.3664445877075195
- 6.843836784362793
- 2.6597533226013184
- 5.827065944671631
- 2.255373477935791
scaled_root_mean_squared_error:
- 12.439318656921387
- 72.18438720703125
- 12.624910354614258
- 70.45165252685547
- 11.619587898254395
seed: 42
time_list:
- '0:01:59.293506'
- '0:01:55.045981'
- '0:01:54.909115'
- '0:01:54.808906'
- '0:01:57.223260'
val_loss:
- 0.25275886058807373
- 0.29345396161079407
- 0.3919019401073456
- 0.5991652011871338
- 0.6718166470527649
val_scaled_mean_absolute_error:
- 36.646087646484375
- 40.90761184692383
- 53.439170837402344
- 77.14754486083984
- 80.34671783447266
val_scaled_root_mean_squared_error:
- 79.33216857910156
- 105.15802001953125
- 138.79078674316406
- 180.42710876464844
- 199.3755645751953
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"model": {"class_name": "make_crystal_model", "module_name": "kgcnn.literature.CGCNN", "config": {"name": "CGCNN", "inputs": [{"shape": [null], "name": "node_number", "dtype": "int64", "ragged": true}, {"shape": [null, 3], "name": "node_frac_coordinates", "dtype": "float64", "ragged": true}, {"shape": [null, 2], "name": "range_indices", "dtype": "int64", "ragged": true}, {"shape": [null, 3], "name": "range_image", "dtype": "float32", "ragged": true}, {"shape": [3, 3], "name": "graph_lattice", "dtype": "float64", "ragged": false}], "input_tensor_type": "ragged", "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "representation": "unit", "expand_distance": true, "make_distances": true, "gauss_args": {"bins": 60, "distance": 6, "offset": 0.0, "sigma": 0.4}, "conv_layer_args": {"units": 128, "activation_s": "kgcnn>shifted_softplus", "activation_out": "kgcnn>shifted_softplus", "batch_normalization": true}, "node_pooling_args": {"pooling_method": "scatter_mean"}, "depth": 4, "output_mlp": {"use_bias": [true, true, false], "units": [128, 64, 1], "activation": ["kgcnn>shifted_softplus", "kgcnn>shifted_softplus", "linear"]}}}, "training": {"cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "fit": {"batch_size": 128, "epochs": 1000, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.001, "learning_rate_stop": 1e-05, "epo_min": 500, "epo": 1000, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.001}}, "loss": "mean_absolute_error"}, "scaler": {"class_name": "StandardLabelScaler", "module_name": "kgcnn.data.transform.scaler.standard", "config": {"with_std": true, "with_mean": true, "copy": true}}}, "data": {"data_unit": "meV/atom"}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "4.0.0"}, "dataset": {"class_name": "MatProjectJdft2dDataset", "module_name": 
"kgcnn.data.datasets.MatProjectJdft2dDataset", "config": {}, "methods": [{"map_list": {"method": "set_range_periodic", "max_distance": 6.0}}]}}
Loading

0 comments on commit 3826a67

Please sign in to comment.