Skip to content

Commit

Permalink
update for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Oct 28, 2023
1 parent 3c1a906 commit 00ecd62
Show file tree
Hide file tree
Showing 4 changed files with 278 additions and 3 deletions.
8 changes: 6 additions & 2 deletions kgcnn/utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,15 @@ def plot_train_test_loss(histories: list, loss_name: str = None,
loss_name = [x for x in list(histories[0].keys()) if "val_" not in x]
if val_loss_name is None:
val_loss_name = [x for x in list(histories[0].keys()) if "val_" in x]

if not isinstance(loss_name, list):
loss_name = [loss_name]
if not isinstance(val_loss_name, list):
val_loss_name = [val_loss_name]
if not isinstance(data_unit, list):
data_unit = [data_unit]

if len(data_unit) < len(val_loss_name):
data_unit = data_unit + [str(data_unit[-1])]*(len(val_loss_name)-len(data_unit))

train_loss = []
for x in loss_name:
Expand Down Expand Up @@ -74,7 +78,7 @@ def plot_train_test_loss(histories: list, loss_name: str = None,
plt.scatter([len(train_loss[i][0])], [np.mean(y, axis=0)[-1]],
label=r"{0}: {1:0.4f} $\pm$ {2:0.4f} ".format(
val_loss_name[i], np.mean(y, axis=0)[-1],
np.std(y, axis=0)[-1]) + data_unit, color=vp[0].get_color()
np.std(y, axis=0)[-1]) + data_unit[i], color=vp[0].get_color()
)
plt.xlabel('Epochs')
plt.ylabel('Loss')
Expand Down
270 changes: 270 additions & 0 deletions training/results/QM9Dataset/Schnet/Schnet_QM9Dataset_score_U.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
OS: nt_win32
backend: tensorflow
cuda_available: 'False'
data_unit: '[''eV'']'
date_time: '2023-10-28 15:35:41'
device_id: '[LogicalDevice(name=''/device:CPU:0'', device_type=''CPU'')]'
device_memory: '[]'
device_name: '[{}]'
epochs:
- 800
- 800
- 800
- 800
- 800
- 800
- 800
- 800
- 800
- 800
execute_folds: []
kgcnn_version: 4.0.0
learning_rate:
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
loss:
- 0.003102271119132638
- 0.0031446837820112705
- 0.003222967963665724
- 0.0031532456632703543
- 0.003179531078785658
- 0.003303215140476823
- 0.0032826880924403667
- 0.0033122405875474215
- 0.003219034057110548
- 0.0031931130215525627
max_learning_rate:
- 0.0005000000237487257
- 0.0005000000237487257
- 0.0005000000237487257
- 0.0005000000237487257
- 0.0005000000237487257
- 0.0005000000237487257
- 0.0005000000237487257
- 0.0005000000237487257
- 0.0005000000237487257
- 0.0005000000237487257
max_loss:
- 0.32900306582450867
- 0.31977978348731995
- 0.315213680267334
- 0.3267098367214203
- 0.32087260484695435
- 0.32263708114624023
- 0.33119043707847595
- 0.32553577423095703
- 0.33107835054397583
- 0.3280688524246216
max_scaled_mean_absolute_error:
- 0.32905193320264625
- 0.3198247700749311
- 0.3152591766524487
- 0.3267146232519014
- 0.32090585330008337
- 0.32266499042006824
- 0.331219953791748
- 0.3255783607024572
- 0.33108641226527835
- 0.3280781928206874
max_scaled_root_mean_squared_error:
- 0.48803940748218383
- 0.4722164986112966
- 0.46687197334126235
- 0.4824642231699354
- 0.47237326096568216
- 0.4767886644757702
- 0.48718154684191073
- 0.48111225911341654
- 0.4896235333899374
- 0.48458047356738754
max_val_loss:
- 0.13950803875923157
- 0.09146694093942642
- 0.08519700914621353
- 0.09302295744419098
- 0.07964996248483658
- 0.09005503356456757
- 0.08714553713798523
- 0.08930784463882446
- 0.14616762101650238
- 0.08222486078739166
max_val_scaled_mean_absolute_error:
- 0.1384251979366305
- 0.09089954509881863
- 0.08493482909156742
- 0.09195206884545477
- 0.07788716696188017
- 0.08892871728155785
- 0.08499441338207311
- 0.08826035654976387
- 0.14532977491327478
- 0.08147701913348938
max_val_scaled_root_mean_squared_error:
- 0.18370449513677514
- 0.13094710846372856
- 0.1282176691272395
- 0.13305046672874024
- 0.13000283246683073
- 0.13205004523360545
- 0.13780788845402625
- 0.127828771209062
- 0.1859635789862445
- 0.12361408291501223
min_learning_rate:
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
- 1.06999996205559e-05
min_loss:
- 0.003102271119132638
- 0.0031446837820112705
- 0.003222967963665724
- 0.0031532456632703543
- 0.003179531078785658
- 0.003303215140476823
- 0.0032826880924403667
- 0.0033122405875474215
- 0.003219034057110548
- 0.0031931130215525627
min_scaled_mean_absolute_error:
- 0.00310240541088055
- 0.003145102704026026
- 0.003223467196040231
- 0.0031533548883815768
- 0.0031800479318645027
- 0.0033036366300158255
- 0.0032831822678082764
- 0.003312745464602815
- 0.0032194061774542034
- 0.0031933249947576877
min_scaled_root_mean_squared_error:
- 0.004267732501010393
- 0.004246839030432776
- 0.004458374770604839
- 0.004275395080879121
- 0.00431145122420274
- 0.005204495085082492
- 0.004462491839127707
- 0.0046641875080630545
- 0.0045305375165805035
- 0.004296213755401796
min_val_loss:
- 0.014781356789171696
- 0.012920898385345936
- 0.01292756199836731
- 0.014434972777962685
- 0.014712795615196228
- 0.013618988916277885
- 0.014511722140014172
- 0.014128431677818298
- 0.012978708371520042
- 0.013119565322995186
min_val_scaled_mean_absolute_error:
- 0.014120716750712173
- 0.012860281101251151
- 0.012784171533398017
- 0.013679683258432147
- 0.013684824596960582
- 0.0129447451102775
- 0.013640340616333098
- 0.013602984113398517
- 0.012948668216790582
- 0.012709200020117863
min_val_scaled_root_mean_squared_error:
- 0.049919464270817504
- 0.03026642499533745
- 0.03513453907510313
- 0.04346468396620625
- 0.039105222914479355
- 0.035136985111746795
- 0.04161567543048673
- 0.03471576091403692
- 0.02727225442509016
- 0.03069353773982778
model_class: make_model
model_name: Schnet
model_version: ''
multi_target_indices:
- 11
number_histories: 10
scaled_mean_absolute_error:
- 0.00310240541088055
- 0.003145102704026026
- 0.003223467196040231
- 0.0031533548883815768
- 0.0031800479318645027
- 0.0033036366300158255
- 0.0032831822678082764
- 0.003312745464602815
- 0.0032194061774542034
- 0.0031933249947576877
scaled_root_mean_squared_error:
- 0.004267732501010393
- 0.004246839030432776
- 0.004458374770604839
- 0.004275395080879121
- 0.00431145122420274
- 0.005204495085082492
- 0.004462491839127707
- 0.0046641875080630545
- 0.0045305375165805035
- 0.004296213755401796
seed: 42
time_list:
- '7:59:50.696123'
- '7:48:22.017206'
- '7:55:55.432746'
- '7:49:30.672315'
- '7:55:53.938183'
- '10:09:28.508289'
- '8:52:12.519615'
- '8:38:43.179948'
- '8:09:18.160746'
- '8:25:30.579937'
val_loss:
- 0.015042244456708431
- 0.012920898385345936
- 0.01292756199836731
- 0.014434972777962685
- 0.014712795615196228
- 0.013618988916277885
- 0.014511722140014172
- 0.014128431677818298
- 0.013102271594107151
- 0.013119565322995186
val_scaled_mean_absolute_error:
- 0.014382796675197521
- 0.012860281101251151
- 0.012784171533398017
- 0.013679683258432147
- 0.013684824596960582
- 0.0129447451102775
- 0.013640340616333098
- 0.013602984113398517
- 0.013072252945093014
- 0.012709200020117863
val_scaled_root_mean_squared_error:
- 0.050914259932424766
- 0.030465935479548335
- 0.03528550774068547
- 0.0436440285448667
- 0.039105222914479355
- 0.035136985111746795
- 0.04161567543048673
- 0.03471576091403692
- 0.027616748351973794
- 0.03166351491351372
1 change: 1 addition & 0 deletions training/results/QM9Dataset/Schnet/Schnet_hyper_U.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"model": {"class_name": "make_model", "module_name": "kgcnn.literature.Schnet", "config": {"name": "Schnet", "inputs": [{"shape": [null], "name": "node_number", "dtype": "int32"}, {"shape": [null, 3], "name": "node_coordinates", "dtype": "float32"}, {"shape": [null, 2], "name": "range_indices", "dtype": "int64"}, {"shape": [], "name": "total_nodes", "dtype": "int64"}, {"shape": [], "name": "total_ranges", "dtype": "int64"}], "cast_disjoint_kwargs": {"padded_disjoint": false}, "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "last_mlp": {"use_bias": [true, true, true], "units": [128, 64, 1], "activation": ["kgcnn>shifted_softplus", "kgcnn>shifted_softplus", "linear"]}, "interaction_args": {"units": 128, "use_bias": true, "activation": "kgcnn>shifted_softplus", "cfconv_pool": "scatter_sum"}, "node_pooling_args": {"pooling_method": "scatter_sum"}, "depth": 4, "gauss_args": {"bins": 20, "distance": 4, "offset": 0.0, "sigma": 0.4}, "verbose": 10, "output_embedding": "graph", "use_output_mlp": false, "output_mlp": null, "output_scaling": {"name": "ExtensiveMolecularLabelScaler"}}}, "training": {"fit": {"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.0005, "learning_rate_stop": 1e-05, "epo_min": 100, "epo": 800, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.0005}}, "loss": {"class_name": "kgcnn>MeanAbsoluteError", "config": {"dtype": "float64"}}, "metrics": [{"class_name": "MeanAbsoluteError", "config": {"dtype": "float64", "name": "scaled_mean_absolute_error"}}, {"class_name": "RootMeanSquaredError", "config": {"dtype": "float64", "name": "scaled_root_mean_squared_error"}}]}, "scaler": {"class_name": "ExtensiveMolecularLabelScaler", "config": {"atomic_number": "node_number"}}, "multi_target_indices": [11]}, "data": {}, "info": {"postfix": "", "postfix_file": "_U", "kgcnn_version": "4.0.0"}, "dataset": {"class_name": "QM9Dataset", "module_name": "kgcnn.data.datasets.QM9Dataset", "config": {}, "methods": [{"set_train_test_indices_k_fold": {"n_splits": 10, "random_state": 42, "shuffle": true}}, {"map_list": {"method": "set_range", "max_distance": 4, "max_neighbours": 30}}, {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", "count_edges": "range_indices", "count_nodes": "node_number", "total_nodes": "total_nodes"}}]}}
2 changes: 1 addition & 1 deletion training/train_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
train_indices_all.append(train_index)

# Only do execute_splits out of the k-folds of cross-validation.
if execute_folds:
if execute_folds is not None:
if current_split not in execute_folds:
continue
print("Running training on split: '%s'." % current_split)
Expand Down

0 comments on commit 00ecd62

Please sign in to comment.