Skip to content

Commit

Permalink
fix hyperparemeter for benchmark training and added benchmark results.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Feb 12, 2024
1 parent ca8063a commit 307a42f
Show file tree
Hide file tree
Showing 6 changed files with 507 additions and 2 deletions.
2 changes: 1 addition & 1 deletion training/hyper/hyper_qm9_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@
"inputs": [
{"shape": (None, 15), "name": "node_attributes", "dtype": "float32", "ragged": True},
{"shape": (None, 3), "name": "node_coordinates", "dtype": "float32", "ragged": True},
{"shape": (None, 1), "name": "range_attributes", "dtype": "int64", "ragged": True},
{"shape": (None, 1), "name": "range_attributes", "dtype": "float32", "ragged": True},
{"shape": (None, 2), "name": "range_indices", "dtype": "int64", "ragged": True}
],
"input_tensor_type": "ragged",
Expand Down
2 changes: 1 addition & 1 deletion training/hyper/hyper_qm9_orbitals.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@
"inputs": [
{"shape": (None, 15), "name": "node_attributes", "dtype": "float32", "ragged": True},
{"shape": (None, 3), "name": "node_coordinates", "dtype": "float32", "ragged": True},
{"shape": (None, 1), "name": "range_attributes", "dtype": "int64", "ragged": True},
{"shape": (None, 1), "name": "range_attributes", "dtype": "float32", "ragged": True},
{"shape": (None, 2), "name": "range_indices", "dtype": "int64", "ragged": True}
],
"input_tensor_type": "ragged",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
OS: posix_linux
backend: tensorflow
cuda_available: 'True'
data_unit: eV/atom
date_time: '2024-02-12 00:09:57'
device_id: '[LogicalDevice(name=''/device:CPU:0'', device_type=''CPU''), LogicalDevice(name=''/device:GPU:0'',
device_type=''GPU'')]'
device_memory: '[]'
device_name: '[{}, {''compute_capability'': (8, 0), ''device_name'': ''NVIDIA A100
80GB PCIe''}]'
epochs:
- 1000
- 1000
- 1000
- 1000
- 1000
execute_folds:
- 4
kgcnn_version: 4.0.0
learning_rate:
- 1.1979999726463575e-05
- 1.1979999726463575e-05
- 1.1979999726463575e-05
- 1.1979999726463575e-05
- 1.1979999726463575e-05
loss:
- 0.0026080473326146603
- 0.0025584164541214705
- 0.002506435848772526
- 0.0025094763841480017
- 0.002450938569381833
max_learning_rate:
- 0.0010000000474974513
- 0.0010000000474974513
- 0.0010000000474974513
- 0.0010000000474974513
- 0.0010000000474974513
max_loss:
- 0.2644650340080261
- 0.2624846398830414
- 0.2626839280128479
- 0.269146591424942
- 0.2654799222946167
max_scaled_mean_absolute_error:
- 0.30768436193466187
- 0.3053379952907562
- 0.30544695258140564
- 0.3133404552936554
- 0.3085813522338867
max_scaled_root_mean_squared_error:
- 0.48864248394966125
- 0.4834892451763153
- 0.48622262477874756
- 0.49567335844039917
- 0.489039808511734
max_val_loss:
- 0.07701177150011063
- 0.07798759639263153
- 0.07978218048810959
- 0.07236430048942566
- 0.07760155946016312
max_val_scaled_mean_absolute_error:
- 0.0895666554570198
- 0.09075972437858582
- 0.09283629059791565
- 0.08433927595615387
- 0.09031086415052414
max_val_scaled_root_mean_squared_error:
- 0.16603174805641174
- 0.1706741750240326
- 0.17729216814041138
- 0.15928615629673004
- 0.17651261389255524
min_learning_rate:
- 1.1979999726463575e-05
- 1.1979999726463575e-05
- 1.1979999726463575e-05
- 1.1979999726463575e-05
- 1.1979999726463575e-05
min_loss:
- 0.0025917908642441034
- 0.0025584164541214705
- 0.002499554306268692
- 0.0024954641703516245
- 0.0024449648335576057
min_scaled_mean_absolute_error:
- 0.00301542691886425
- 0.002976306015625596
- 0.002906341338530183
- 0.0029053455218672752
- 0.0028418991714715958
min_scaled_root_mean_squared_error:
- 0.011370769701898098
- 0.010693884454667568
- 0.009863811545073986
- 0.011626441963016987
- 0.011576841585338116
min_val_loss:
- 0.02578863501548767
- 0.02525782398879528
- 0.02570636384189129
- 0.02563677355647087
- 0.02539525367319584
min_val_scaled_mean_absolute_error:
- 0.030011145398020744
- 0.029425358399748802
- 0.029944349080324173
- 0.02989417500793934
- 0.029561277478933334
min_val_scaled_root_mean_squared_error:
- 0.07147400826215744
- 0.07076342403888702
- 0.07674837857484818
- 0.07776230573654175
- 0.07215573638677597
model_class: make_crystal_model
model_name: CGCNN
model_version: '2023-11-28'
multi_target_indices: null
number_histories: 5
scaled_mean_absolute_error:
- 0.0030343900434672832
- 0.002976306015625596
- 0.0029143390711396933
- 0.00292161270044744
- 0.002848853589966893
scaled_root_mean_squared_error:
- 0.011370769701898098
- 0.010693884454667568
- 0.00986681692302227
- 0.01163093838840723
- 0.011578156612813473
seed: 42
time_list:
- '17:20:18.229154'
- '17:24:47.711764'
- '17:27:54.787246'
- '17:29:50.163244'
- '17:24:21.128653'
val_loss:
- 0.02578863501548767
- 0.02525782398879528
- 0.025737645104527473
- 0.02563677355647087
- 0.02539525367319584
val_scaled_mean_absolute_error:
- 0.030011145398020744
- 0.029425358399748802
- 0.029976332560181618
- 0.02989417500793934
- 0.029561277478933334
val_scaled_root_mean_squared_error:
- 0.07228276133537292
- 0.07152009755373001
- 0.07757915556430817
- 0.07864636182785034
- 0.07326369732618332
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"model": {"class_name": "make_crystal_model", "module_name": "kgcnn.literature.CGCNN", "config": {"name": "CGCNN", "inputs": [{"shape": [null], "name": "node_number", "dtype": "int64", "ragged": true}, {"shape": [null, 3], "name": "node_frac_coordinates", "dtype": "float64", "ragged": true}, {"shape": [null, 2], "name": "range_indices", "dtype": "int64", "ragged": true}, {"shape": [null, 3], "name": "range_image", "dtype": "float32", "ragged": true}, {"shape": [3, 3], "name": "graph_lattice", "dtype": "float64", "ragged": false}], "input_tensor_type": "ragged", "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "representation": "unit", "expand_distance": true, "make_distances": true, "gauss_args": {"bins": 60, "distance": 6, "offset": 0.0, "sigma": 0.4}, "conv_layer_args": {"units": 128, "activation_s": "kgcnn>shifted_softplus", "activation_out": "kgcnn>shifted_softplus", "batch_normalization": true}, "node_pooling_args": {"pooling_method": "mean"}, "depth": 4, "output_mlp": {"use_bias": [true, true, false], "units": [128, 64, 1], "activation": ["kgcnn>shifted_softplus", "kgcnn>shifted_softplus", "linear"]}}}, "training": {"cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "fit": {"batch_size": 128, "epochs": 1000, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.001, "learning_rate_stop": 1e-05, "epo_min": 500, "epo": 1000, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.001}}, "loss": "mean_absolute_error"}, "scaler": {"class_name": "StandardLabelScaler", "module_name": "kgcnn.data.transform.scaler.standard", "config": {"with_std": true, "with_mean": true, "copy": true}}, "multi_target_indices": null}, "data": {"data_unit": "eV/atom"}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "4.0.0"}, "dataset": {"class_name": "MatProjectEFormDataset", "module_name": "kgcnn.data.datasets.MatProjectEFormDataset", "config": {}, "methods": [{"map_list": {"method": "set_range_periodic", "max_distance": 6.0}}]}}
1 change: 1 addition & 0 deletions training/results/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ Materials Project dataset from Matbench with 132752 crystal structures and their

| model | kgcnn | epochs | MAE [eV/atom] | RMSE [eV/atom] |
|:--------------------------|:--------|---------:|:-----------------------|:-----------------------|
| CGCNN.make_crystal_model | 4.0.0 | 1000 | 0.0298 ± 0.0002 | 0.0747 ± 0.0029 |
| Schnet.make_crystal_model | 4.0.0 | 800 | **0.0211 ± 0.0003** | **0.0510 ± 0.0024** |

#### MatProjectGapDataset
Expand Down
Loading

0 comments on commit 307a42f

Please sign in to comment.