Skip to content

Commit

Permalink
update for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Nov 20, 2023
1 parent 1361483 commit bd6200b
Show file tree
Hide file tree
Showing 7 changed files with 279 additions and 22 deletions.
11 changes: 6 additions & 5 deletions kgcnn/literature/PAiNN/_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,30 +263,31 @@ def call(self, inputs, **kwargs):
inputs = ops.repeat(inputs, self.units, axis=-1)

if self.method == "zeros":
out = ops.zeros_like(inputs)
out = ops.zeros_like(inputs, dtype=self.dtype)
out = ops.expand_dims(out, axis=1)
out = ops.repeat(out, self.dim, axis=1)
elif self.method == "eps":
out = ops.zeros_like(inputs) + ks.backend.epsilon()
out = ops.zeros_like(inputs, dtype=self.dtype) + ks.backend.epsilon()
out = ops.expand_dims(out, axis=1)
out = ops.repeat(out, self.dim, axis=1)
elif self.method == "normal":
raise NotImplementedError()
elif self.method == "ones":
out = ops.ones_like(inputs)
out = ops.ones_like(inputs, dtype=self.dtype)
out = ops.expand_dims(out, axis=1)
out = ops.repeat(out, self.dim, axis=1)
elif self.method == "eye":
out = ops.eye(self.dim, ops.shape(inputs)[1], dtype=inputs.dtype)
out = ops.eye(self.dim, ops.shape(inputs)[1], dtype=self.dtype)
out = ops.expand_dims(out, axis=0)
out = ops.repeat(out, ops.shape(inputs)[0], axis=0)
elif self.method == "const":
out = ops.ones_like(inputs)*self.value
out = ops.ones_like(inputs, dtype=self.dtype)*self.value
out = ops.expand_dims(out, axis=1)
out = ops.repeat(out, self.dim, axis=1)
elif self.method == "node":
out = ops.expand_dims(inputs, axis=1)
out = ops.repeat(out, self.dim, axis=1)
out = ops.cast(out, dtype=self.dtype)
else:
raise ValueError("Unknown initialization method %s" % self.method)
return out
Expand Down
92 changes: 92 additions & 0 deletions training/hyper/hyper_md17.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,96 @@
"kgcnn_version": "4.0.0"
}
},
"PAiNN.EnergyForceModel": {
"model": {
"class_name": "EnergyForceModel",
"module_name": "kgcnn.models.force",
"config": {
"name": "PAiNN",
"nested_model_config": True,
"output_to_tensor": False,
"output_squeeze_states": True,
"coordinate_input": 1,
"inputs": [
{"shape": [None], "name": "atomic_number", "dtype": "int32"},
{"shape": [None, 3], "name": "node_coordinates", "dtype": "float32"},
{"shape": [None, 2], "name": "range_indices", "dtype": "int64"},
{"shape": (), "name": "total_nodes", "dtype": "int64"},
{"shape": (), "name": "total_ranges", "dtype": "int64"}
],
"model_energy": {
"class_name": "make_model",
"module_name": "kgcnn.literature.PAiNN",
"config": {
"name": "PAiNNEnergy",
"inputs": [
{"shape": [None], "name": "atomic_number", "dtype": "int32"},
{"shape": [None, 3], "name": "node_coordinates", "dtype": "float32"},
{"shape": [None, 2], "name": "range_indices", "dtype": "int64"},
{"shape": (), "name": "total_nodes", "dtype": "int64"},
{"shape": (), "name": "total_ranges", "dtype": "int64"}
],
"input_embedding": None,
"input_node_embedding": {"input_dim": 95, "output_dim": 128},
"equiv_initialize_kwargs": {"dim": 3, "method": "eps"},
"bessel_basis": {"num_radial": 20, "cutoff": 5.0, "envelope_exponent": 5},
"pooling_args": {"pooling_method": "scatter_sum"},
"conv_args": {"units": 128, "cutoff": None},
"update_args": {"units": 128}, "depth": 3, "verbose": 10,
"output_embedding": "graph",
"output_mlp": {"use_bias": [True, True], "units": [128, 1], "activation": ["swish", "linear"]},
}
},
"outputs": {"energy": {"name": "energy", "shape": (1,)},
"force": {"name": "force", "shape": (None, 3)}}
}
},
"training": {
"fit": {
"batch_size": 32, "epochs": 1000, "validation_freq": 1, "verbose": 2,
"callbacks": []
},
"compile": {
"optimizer": {
"class_name": "Adam", "config": {
"learning_rate": {
"class_name": "kgcnn>LinearWarmupExponentialDecay", "config": {
"learning_rate": 0.001, "warmup_steps": 150.0, "decay_steps": 20000.0,
"decay_rate": 0.01
}
}, "amsgrad": True, "use_ema": True
}
},
"loss_weights": {"energy": 0.02, "force": 0.98}
},
"scaler": {"class_name": "EnergyForceExtensiveLabelScaler",
"config": {"standardize_scale": True}},
},
"data": {
},
"dataset": {
"class_name": "MD17Dataset",
"module_name": "kgcnn.data.datasets.MD17Dataset",
"config": {
# toluene_ccsd_t, aspirin_ccsd, malonaldehyde_ccsd_t, benzene_ccsd_t, ethanol_ccsd_t
"trajectory_name": trajectory_name
},
"methods": [
{"rename_property_on_graphs": {"old_property_name": "E", "new_property_name": "energy"}},
{"rename_property_on_graphs": {"old_property_name": "F", "new_property_name": "force"}},
{"rename_property_on_graphs": {"old_property_name": "z", "new_property_name": "atomic_number"}},
{"rename_property_on_graphs": {"old_property_name": "R", "new_property_name": "node_coordinates"}},
{"map_list": {"method": "set_range", "max_distance": 5, "max_neighbours": 10000,
"node_coordinates": "node_coordinates"}},
{"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges",
"count_edges": "range_indices", "count_nodes": "atomic_number",
"total_nodes": "total_nodes"}},
]
},
"info": {
"postfix": "",
"postfix_file": "_"+trajectory_name,
"kgcnn_version": "4.0.0"
}
},
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
OS: posix_linux
backend: torch
cuda_available: 'True'
data_unit: eV
date_time: '2023-11-17 13:52:35'
device_id: '[0]'
device_memory: '[{''allocated'': 0.0, ''cached'': 3.3}]'
device_name: '[''NVIDIA A100 80GB PCIe'']'
epochs:
- 800
- 800
- 800
- 800
- 800
execute_folds:
- 4
kgcnn_version: 4.0.0
learning_rate:
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
loss:
- 0.7510426640510559
- 0.7514869570732117
- 0.0296995397657156
- 0.09922590851783752
- 0.752254843711853
max_learning_rate:
- 0.0005000000237487257
- 0.0005000000237487257
- 0.0005000000237487257
- 0.0005000000237487257
- 0.0005000000237487257
max_loss:
- 152.89744567871094
- 2.6082866191864014
- 1.3898086547851562
- 43.36865234375
- 70.06098175048828
max_scaled_mean_absolute_error:
- 244.31700134277344
- 4.172955513000488
- 2.223109722137451
- 69.44451141357422
- 112.0015640258789
max_scaled_root_mean_squared_error:
- 25650.35546875
- 865.424072265625
- 533.6707153320312
- 19664.10546875
- 20820.58203125
max_val_loss:
- 0.7587342262268066
- 1.2743823528289795
- 0.7526670098304749
- 0.7512235045433044
- 2.563725709915161
max_val_scaled_mean_absolute_error:
- 1.210221767425537
- 2.039233446121216
- 1.20370614528656
- 1.2025171518325806
- 4.099420070648193
max_val_scaled_root_mean_squared_error:
- 1.9370406866073608
- 110.60933685302734
- 1.8700799942016602
- 1.881477952003479
- 422.8948669433594
min_learning_rate:
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
min_loss:
- 0.1837150901556015
- 0.17997939884662628
- 0.0296995397657156
- 0.09922590851783752
- 0.17990800738334656
min_scaled_mean_absolute_error:
- 0.2933097779750824
- 0.28793007135391235
- 0.04748707264661789
- 0.15884126722812653
- 0.2875083088874817
min_scaled_root_mean_squared_error:
- 0.6002561450004578
- 0.5936135053634644
- 0.204745814204216
- 0.3796391487121582
- 0.593903124332428
min_val_loss:
- 0.20087158679962158
- 0.19898207485675812
- 0.1612144559621811
- 0.174364373087883
- 0.198722705245018
min_val_scaled_mean_absolute_error:
- 0.32050538063049316
- 0.31839296221733093
- 0.25761139392852783
- 0.2782728374004364
- 0.31677064299583435
min_val_scaled_root_mean_squared_error:
- 0.6462162733078003
- 0.6535683274269104
- 0.5569960474967957
- 0.5836672186851501
- 0.6555573344230652
model_class: make_crystal_model
model_name: Schnet
model_version: '2023-09-07'
multi_target_indices: null
number_histories: 5
scaled_mean_absolute_error:
- 1.1996724605560303
- 1.201776385307312
- 0.04748707264661789
- 0.15884126722812653
- 1.202290654182434
scaled_root_mean_squared_error:
- 1.8552837371826172
- 1.8511769771575928
- 0.204745814204216
- 0.38051795959472656
- 1.8522521257400513
seed: 42
time_list:
- '17:33:58.401059'
- '17:09:57.241122'
- '16:43:25.549148'
- '16:16:21.689500'
- '16:59:59.434099'
val_loss:
- 0.756482720375061
- 0.7494837641716003
- 0.16426515579223633
- 0.17509213089942932
- 1.9805660247802734
val_scaled_mean_absolute_error:
- 1.2068122625350952
- 1.19844651222229
- 0.26236721873283386
- 0.2794942855834961
- 3.16611647605896
val_scaled_root_mean_squared_error:
- 1.8646117448806763
- 1.852222204208374
- 0.5867417454719543
- 0.5928786396980286
- 286.9599609375
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"model": {"module_name": "kgcnn.literature.Schnet", "class_name": "make_crystal_model", "config": {"name": "Schnet", "inputs": [{"shape": [null], "name": "node_number", "dtype": "int32"}, {"shape": [null, 3], "name": "node_coordinates", "dtype": "float32"}, {"shape": [null, 2], "name": "range_indices", "dtype": "int64"}, {"shape": [null, 3], "name": "range_image", "dtype": "int64"}, {"shape": [3, 3], "name": "graph_lattice", "dtype": "float32"}, {"shape": [], "name": "total_nodes", "dtype": "int64"}, {"shape": [], "name": "total_ranges", "dtype": "int64"}], "cast_disjoint_kwargs": {"padded_disjoint": false}, "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "interaction_args": {"units": 128, "use_bias": true, "activation": "kgcnn>shifted_softplus", "cfconv_pool": "scatter_sum"}, "node_pooling_args": {"pooling_method": "scatter_mean"}, "depth": 4, "gauss_args": {"bins": 25, "distance": 5, "offset": 0.0, "sigma": 0.4}, "verbose": 10, "last_mlp": {"use_bias": [true, true, true], "units": [128, 64, 1], "activation": ["kgcnn>shifted_softplus", "kgcnn>shifted_softplus", "linear"]}, "output_embedding": "graph", "use_output_mlp": false, "output_mlp": null}}, "training": {"cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "fit": {"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.0005, "learning_rate_stop": 1e-05, "epo_min": 100, "epo": 800, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.0005}}, "loss": "mean_absolute_error"}, "scaler": {"class_name": "StandardLabelScaler", "module_name": "kgcnn.data.transform.scaler.standard", "config": {"with_std": true, "with_mean": true, "copy": true}}, "multi_target_indices": null}, "data": {"data_unit": "eV"}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "4.0.0"}, "dataset": {"class_name": "MatProjectGapDataset", "module_name": "kgcnn.data.datasets.MatProjectGapDataset", "config": {}, "methods": [{"map_list": {"method": "set_range_periodic", "max_distance": 5, "max_neighbours": 32}}, {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", "count_edges": "range_indices", "count_nodes": "node_number", "total_nodes": "total_nodes"}}]}}
8 changes: 8 additions & 0 deletions training/results/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ Materials Project dataset from Matbench with 4764 crystal structures and their c
|:--------------------------|:--------|---------:|:-----------------------|:-----------------------|
| Schnet.make_crystal_model | 4.0.0 | 800 | **0.3180 ± 0.0359** | **1.8509 ± 0.5854** |

#### MatProjectGapDataset

Materials Project dataset from Matbench with 106113 crystal structures and their band gap as calculated by PBE DFT from the Materials Project, in eV. We use a random 5-fold cross-validation.

| model | kgcnn | epochs | MAE [eV] | RMSE [eV] |
|:--------------------------|:--------|---------:|:-----------------------|:--------------------------|
| Schnet.make_crystal_model | 4.0.0 | 800 | **1.2226 ± 1.0573** | **58.3713 ± 114.2957** |

#### MatProjectIsMetalDataset

Materials Project dataset from Matbench with 106113 crystal structures and their corresponding Metallicity determined with pymatgen. 1 if the compound is a metal, 0 if the compound is not a metal. We use a random 5-fold cross-validation.
Expand Down
30 changes: 15 additions & 15 deletions training/results/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,21 +177,21 @@
# "is_min_max": True}
# ]
# },
# "MatProjectGapDataset": {
# "general_info": [
# "Materials Project dataset from Matbench with 106113 crystal structures ",
# "and their band gap as calculated by PBE DFT from the Materials Project, in eV. ",
# "We use a random 5-fold cross-validation. "
# ],
# "targets": [
# {"metric": "val_scaled_mean_absolute_error", "name": "MAE [eV]", "find_best": "min"},
# {"metric": "val_scaled_root_mean_squared_error", "name": "RMSE [eV]", "find_best": "min"},
# {"metric": "min_val_scaled_mean_absolute_error", "name": "*Min. MAE*", "find_best": "min",
# "is_min_max": True},
# {"metric": "min_val_scaled_root_mean_squared_error", "name": "*Min. RMSE*", "find_best": "min",
# "is_min_max": True}
# ]
# },
"MatProjectGapDataset": {
"general_info": [
"Materials Project dataset from Matbench with 106113 crystal structures ",
"and their band gap as calculated by PBE DFT from the Materials Project, in eV. ",
"We use a random 5-fold cross-validation. "
],
"targets": [
{"metric": "val_scaled_mean_absolute_error", "name": "MAE [eV]", "find_best": "min"},
{"metric": "val_scaled_root_mean_squared_error", "name": "RMSE [eV]", "find_best": "min"},
{"metric": "min_val_scaled_mean_absolute_error", "name": "*Min. MAE*", "find_best": "min",
"is_min_max": True},
{"metric": "min_val_scaled_root_mean_squared_error", "name": "*Min. RMSE*", "find_best": "min",
"is_min_max": True}
]
},
"MatProjectIsMetalDataset": {
"general_info": [
"Materials Project dataset from Matbench with 106113 crystal structures ",
Expand Down
4 changes: 2 additions & 2 deletions training/train_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
# Input arguments from command line.
parser = argparse.ArgumentParser(description='Train a GNN on an Energy-Force Dataset.')
parser.add_argument("--hyper", required=False, help="Filepath to hyper-parameter config file (.py or .json).",
default="hyper/hyper_iso17.py")
parser.add_argument("--category", required=False, help="Graph model to train.", default="Schnet.EnergyForceModel")
default="hyper/hyper_md17.py")
parser.add_argument("--category", required=False, help="Graph model to train.", default="PAiNN.EnergyForceModel")
parser.add_argument("--model", required=False, help="Graph model to train.", default=None)
parser.add_argument("--dataset", required=False, help="Name of the dataset.", default=None)
parser.add_argument("--make", required=False, help="Name of the class for model.", default=None)
Expand Down

0 comments on commit bd6200b

Please sign in to comment.