Skip to content

Commit

Permalink
update for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Nov 15, 2023
1 parent a667d3a commit d884686
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 1 deletion.
4 changes: 3 additions & 1 deletion kgcnn/literature/GAT/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,13 @@ def make_model(inputs: list = None,
Model inputs:
Model uses the list template of inputs and standard output template.
The supported inputs are :obj:`[nodes, edges, edge_indices, ...]`
with '...' indicating mask or id tensors following the template below:
with '...' indicating mask or ID tensors following the template below:
%s
Model outputs:
The standard output template:
%s
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
OS: nt_win32
backend: tensorflow
cuda_available: 'False'
data_unit: meV/atom
date_time: '2023-11-15 15:22:58'
device_id: '[LogicalDevice(name=''/device:CPU:0'', device_type=''CPU'')]'
device_memory: '[]'
device_name: '[{}]'
epochs:
- 800
- 800
- 800
- 800
- 800
execute_folds: null
kgcnn_version: 4.0.0
learning_rate:
- 2.1409230612334795e-05
- 2.1409230612334795e-05
- 2.1409230612334795e-05
- 2.1409230612334795e-05
- 2.1409230612334795e-05
loss:
- 0.01342229824513197
- 0.01582755520939827
- 0.015029882080852985
- 0.016817400231957436
- 0.014265685342252254
max_learning_rate:
- 9.999999747378752e-05
- 9.999999747378752e-05
- 9.999999747378752e-05
- 9.999999747378752e-05
- 9.999999747378752e-05
max_loss:
- 0.5201069712638855
- 0.5315175652503967
- 0.451824426651001
- 0.5210411548614502
- 0.4851284623146057
max_scaled_mean_absolute_error:
- 75.4897232055664
- 74.12128448486328
- 61.5352897644043
- 67.2265625
- 57.95024871826172
max_scaled_root_mean_squared_error:
- 149.78330993652344
- 143.6664276123047
- 140.2177734375
- 131.99143981933594
- 122.83416748046875
max_val_loss:
- 0.2821546494960785
- 0.3245273530483246
- 0.3746875822544098
- 0.5050886273384094
- 0.618653416633606
max_val_scaled_mean_absolute_error:
- 40.90802001953125
- 45.22543716430664
- 51.08058166503906
- 65.11273956298828
- 73.94288635253906
max_val_scaled_root_mean_squared_error:
- 84.96401977539062
- 111.83113861083984
- 129.0309600830078
- 159.9274444580078
- 187.96551513671875
min_learning_rate:
- 9.99999993922529e-09
- 9.99999993922529e-09
- 9.99999993922529e-09
- 9.99999993922529e-09
- 9.99999993922529e-09
min_loss:
- 0.013281172141432762
- 0.01350078172981739
- 0.01428682915866375
- 0.013030022382736206
- 0.011631934903562069
min_scaled_mean_absolute_error:
- 1.9360378980636597
- 1.8882485628128052
- 1.9466769695281982
- 1.6665455102920532
- 1.3933351039886475
min_scaled_root_mean_squared_error:
- 9.1520357131958
- 7.898106575012207
- 9.539583206176758
- 9.564427375793457
- 6.542611122131348
min_val_loss:
- 0.21616201102733612
- 0.2801247239112854
- 0.3179992735385895
- 0.44068655371665955
- 0.5248962044715881
min_val_scaled_mean_absolute_error:
- 31.34011459350586
- 39.01124572753906
- 43.35005569458008
- 56.815731048583984
- 62.80753707885742
min_val_scaled_root_mean_squared_error:
- 68.39794158935547
- 100.04585266113281
- 115.25602722167969
- 143.70919799804688
- 162.13404846191406
model_class: make_crystal_model
model_name: PAiNN
model_version: '2023-10-04'
multi_target_indices: null
number_histories: 5
scaled_mean_absolute_error:
- 1.9560686349868774
- 2.2146451473236084
- 2.046118974685669
- 2.1549291610717773
- 1.7081328630447388
scaled_root_mean_squared_error:
- 9.223356246948242
- 8.082352638244629
- 9.63218879699707
- 9.677685737609863
- 6.588728427886963
seed: 42
time_list:
- '0:31:53.683697'
- '0:33:27.269425'
- '1:00:06.044732'
- '1:13:54.304826'
- '1:49:35.146251'
val_loss:
- 0.2428491860628128
- 0.29850465059280396
- 0.3226911723613739
- 0.4874231219291687
- 0.5294967293739319
val_scaled_mean_absolute_error:
- 35.20934295654297
- 41.602272033691406
- 43.98591613769531
- 62.783592224121094
- 63.3631706237793
val_scaled_root_mean_squared_error:
- 80.62264251708984
- 101.75128936767578
- 115.29129791259766
- 147.98414611816406
- 162.89393615722656
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"model": {"module_name": "kgcnn.literature.PAiNN", "class_name": "make_crystal_model", "config": {"name": "PAiNN", "inputs": [{"shape": [null], "name": "node_number", "dtype": "int64", "ragged": true}, {"shape": [null, 3], "name": "node_coordinates", "dtype": "float32", "ragged": true}, {"shape": [null, 2], "name": "range_indices", "dtype": "int64", "ragged": true}, {"shape": [null, 3], "name": "range_image", "dtype": "int64", "ragged": true}, {"shape": [3, 3], "name": "graph_lattice", "dtype": "float32", "ragged": false}], "input_tensor_type": "ragged", "cast_disjoint_kwargs": {}, "input_node_embedding": {"input_dim": 95, "output_dim": 128}, "equiv_initialize_kwargs": {"dim": 3, "method": "eye"}, "bessel_basis": {"num_radial": 20, "cutoff": 5.0, "envelope_exponent": 5}, "pooling_args": {"pooling_method": "scatter_mean"}, "conv_args": {"units": 128, "cutoff": null, "conv_pool": "scatter_sum"}, "update_args": {"units": 128}, "depth": 2, "verbose": 10, "equiv_normalization": false, "node_normalization": false, "output_embedding": "graph", "output_mlp": {"use_bias": [true, true], "units": [128, 1], "activation": ["swish", "linear"]}}}, "training": {"cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "fit": {"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearWarmupLinearLearningRateScheduler", "config": {"learning_rate_start": 0.0001, "learning_rate_stop": 1e-06, "epo_warmup": 25, "epo": 1000, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.0001}}, "loss": "mean_absolute_error"}, "scaler": {"class_name": "StandardLabelScaler", "module_name": "kgcnn.data.transform.scaler.standard", "config": {"with_std": true, "with_mean": true, "copy": true}}, "multi_target_indices": null}, "data": {"data_unit": "meV/atom"}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "4.0.0"}, "dataset": {"class_name": "MatProjectJdft2dDataset", "module_name": "kgcnn.data.datasets.MatProjectJdft2dDataset", "config": {}, "methods": [{"map_list": {"method": "set_range_periodic", "max_distance": 5.0}}, {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", "count_edges": "range_indices", "count_nodes": "node_number", "total_nodes": "total_nodes"}}]}}
1 change: 1 addition & 0 deletions training/results/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ Materials Project dataset from Matbench with 636 crystal structures and their co

| model | kgcnn | epochs | MAE [meV/atom] | RMSE [meV/atom] |
|:--------------------------|:--------|---------:|:-------------------------|:--------------------------|
| PAiNN.make_crystal_model | 4.0.0 | 800 | 49.3889 ± 11.5376 | 121.7087 ± 30.0472 |
| Schnet.make_crystal_model | 4.0.0 | 800 | **45.2412 ± 11.6395** | **115.6890 ± 39.0929** |

#### MatProjectLogGVRHDataset
Expand Down

0 comments on commit d884686

Please sign in to comment.