diff --git a/kgcnn/utils/plots.py b/kgcnn/utils/plots.py index ccd799ee..bf07042e 100644 --- a/kgcnn/utils/plots.py +++ b/kgcnn/utils/plots.py @@ -37,11 +37,15 @@ def plot_train_test_loss(histories: list, loss_name: str = None, loss_name = [x for x in list(histories[0].keys()) if "val_" not in x] if val_loss_name is None: val_loss_name = [x for x in list(histories[0].keys()) if "val_" in x] - if not isinstance(loss_name, list): loss_name = [loss_name] if not isinstance(val_loss_name, list): val_loss_name = [val_loss_name] + if not isinstance(data_unit, list): + data_unit = [data_unit] + + if len(data_unit) < len(val_loss_name): + data_unit = data_unit + [str(data_unit[-1])]*(len(val_loss_name)-len(data_unit)) train_loss = [] for x in loss_name: @@ -74,7 +78,7 @@ def plot_train_test_loss(histories: list, loss_name: str = None, plt.scatter([len(train_loss[i][0])], [np.mean(y, axis=0)[-1]], label=r"{0}: {1:0.4f} $\pm$ {2:0.4f} ".format( val_loss_name[i], np.mean(y, axis=0)[-1], - np.std(y, axis=0)[-1]) + data_unit, color=vp[0].get_color() + np.std(y, axis=0)[-1]) + data_unit[i], color=vp[0].get_color() ) plt.xlabel('Epochs') plt.ylabel('Loss') diff --git a/training/results/QM9Dataset/Schnet/Schnet_QM9Dataset_score_U.yaml b/training/results/QM9Dataset/Schnet/Schnet_QM9Dataset_score_U.yaml new file mode 100644 index 00000000..2275372d --- /dev/null +++ b/training/results/QM9Dataset/Schnet/Schnet_QM9Dataset_score_U.yaml @@ -0,0 +1,270 @@ +OS: nt_win32 +backend: tensorflow +cuda_available: 'False' +data_unit: '[''eV'']' +date_time: '2023-10-28 15:35:41' +device_id: '[LogicalDevice(name=''/device:CPU:0'', device_type=''CPU'')]' +device_memory: '[]' +device_name: '[{}]' +epochs: +- 800 +- 800 +- 800 +- 800 +- 800 +- 800 +- 800 +- 800 +- 800 +- 800 +execute_folds: [] +kgcnn_version: 4.0.0 +learning_rate: +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +loss: +- 0.003102271119132638 +- 0.0031446837820112705 +- 0.003222967963665724 +- 0.0031532456632703543 +- 0.003179531078785658 +- 0.003303215140476823 +- 0.0032826880924403667 +- 0.0033122405875474215 +- 0.003219034057110548 +- 0.0031931130215525627 +max_learning_rate: +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +max_loss: +- 0.32900306582450867 +- 0.31977978348731995 +- 0.315213680267334 +- 0.3267098367214203 +- 0.32087260484695435 +- 0.32263708114624023 +- 0.33119043707847595 +- 0.32553577423095703 +- 0.33107835054397583 +- 0.3280688524246216 +max_scaled_mean_absolute_error: +- 0.32905193320264625 +- 0.3198247700749311 +- 0.3152591766524487 +- 0.3267146232519014 +- 0.32090585330008337 +- 0.32266499042006824 +- 0.331219953791748 +- 0.3255783607024572 +- 0.33108641226527835 +- 0.3280781928206874 +max_scaled_root_mean_squared_error: +- 0.48803940748218383 +- 0.4722164986112966 +- 0.46687197334126235 +- 0.4824642231699354 +- 0.47237326096568216 +- 0.4767886644757702 +- 0.48718154684191073 +- 0.48111225911341654 +- 0.4896235333899374 +- 0.48458047356738754 +max_val_loss: +- 0.13950803875923157 +- 0.09146694093942642 +- 0.08519700914621353 +- 0.09302295744419098 +- 0.07964996248483658 +- 0.09005503356456757 +- 0.08714553713798523 +- 0.08930784463882446 +- 0.14616762101650238 +- 0.08222486078739166 +max_val_scaled_mean_absolute_error: +- 0.1384251979366305 +- 0.09089954509881863 +- 0.08493482909156742 +- 0.09195206884545477 +- 0.07788716696188017 +- 0.08892871728155785 +- 0.08499441338207311 +- 0.08826035654976387 +- 0.14532977491327478 +- 0.08147701913348938 +max_val_scaled_root_mean_squared_error: +- 0.18370449513677514 +- 0.13094710846372856 +- 0.1282176691272395 +- 0.13305046672874024 +- 0.13000283246683073 +- 0.13205004523360545 +- 0.13780788845402625 +- 0.127828771209062 +- 0.1859635789862445 +- 0.12361408291501223 +min_learning_rate: +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +- 1.06999996205559e-05 +min_loss: +- 0.003102271119132638 +- 0.0031446837820112705 +- 0.003222967963665724 +- 0.0031532456632703543 +- 0.003179531078785658 +- 0.003303215140476823 +- 0.0032826880924403667 +- 0.0033122405875474215 +- 0.003219034057110548 +- 0.0031931130215525627 +min_scaled_mean_absolute_error: +- 0.00310240541088055 +- 0.003145102704026026 +- 0.003223467196040231 +- 0.0031533548883815768 +- 0.0031800479318645027 +- 0.0033036366300158255 +- 0.0032831822678082764 +- 0.003312745464602815 +- 0.0032194061774542034 +- 0.0031933249947576877 +min_scaled_root_mean_squared_error: +- 0.004267732501010393 +- 0.004246839030432776 +- 0.004458374770604839 +- 0.004275395080879121 +- 0.00431145122420274 +- 0.005204495085082492 +- 0.004462491839127707 +- 0.0046641875080630545 +- 0.0045305375165805035 +- 0.004296213755401796 +min_val_loss: +- 0.014781356789171696 +- 0.012920898385345936 +- 0.01292756199836731 +- 0.014434972777962685 +- 0.014712795615196228 +- 0.013618988916277885 +- 0.014511722140014172 +- 0.014128431677818298 +- 0.012978708371520042 +- 0.013119565322995186 +min_val_scaled_mean_absolute_error: +- 0.014120716750712173 +- 0.012860281101251151 +- 0.012784171533398017 +- 0.013679683258432147 +- 0.013684824596960582 +- 0.0129447451102775 +- 0.013640340616333098 +- 0.013602984113398517 +- 0.012948668216790582 +- 0.012709200020117863 +min_val_scaled_root_mean_squared_error: +- 0.049919464270817504 +- 0.03026642499533745 +- 0.03513453907510313 +- 0.04346468396620625 +- 0.039105222914479355 +- 0.035136985111746795 +- 0.04161567543048673 +- 0.03471576091403692 +- 0.02727225442509016 +- 0.03069353773982778 +model_class: make_model +model_name: Schnet +model_version: '' +multi_target_indices: +- 11 +number_histories: 10 +scaled_mean_absolute_error: +- 0.00310240541088055 +- 0.003145102704026026 +- 0.003223467196040231 +- 0.0031533548883815768 +- 0.0031800479318645027 +- 0.0033036366300158255 +- 0.0032831822678082764 +- 0.003312745464602815 +- 0.0032194061774542034 +- 0.0031933249947576877 +scaled_root_mean_squared_error: +- 0.004267732501010393 +- 0.004246839030432776 +- 0.004458374770604839 +- 0.004275395080879121 +- 0.00431145122420274 +- 0.005204495085082492 +- 0.004462491839127707 +- 0.0046641875080630545 +- 0.0045305375165805035 +- 0.004296213755401796 +seed: 42 +time_list: +- '7:59:50.696123' +- '7:48:22.017206' +- '7:55:55.432746' +- '7:49:30.672315' +- '7:55:53.938183' +- '10:09:28.508289' +- '8:52:12.519615' +- '8:38:43.179948' +- '8:09:18.160746' +- '8:25:30.579937' +val_loss: +- 0.015042244456708431 +- 0.012920898385345936 +- 0.01292756199836731 +- 0.014434972777962685 +- 0.014712795615196228 +- 0.013618988916277885 +- 0.014511722140014172 +- 0.014128431677818298 +- 0.013102271594107151 +- 0.013119565322995186 +val_scaled_mean_absolute_error: +- 0.014382796675197521 +- 0.012860281101251151 +- 0.012784171533398017 +- 0.013679683258432147 +- 0.013684824596960582 +- 0.0129447451102775 +- 0.013640340616333098 +- 0.013602984113398517 +- 0.013072252945093014 +- 0.012709200020117863 +val_scaled_root_mean_squared_error: +- 0.050914259932424766 +- 0.030465935479548335 +- 0.03528550774068547 +- 0.0436440285448667 +- 0.039105222914479355 +- 0.035136985111746795 +- 0.04161567543048673 +- 0.03471576091403692 +- 0.027616748351973794 +- 0.03166351491351372 diff --git a/training/results/QM9Dataset/Schnet/Schnet_hyper_U.json b/training/results/QM9Dataset/Schnet/Schnet_hyper_U.json new file mode 100644 index 00000000..cba87e49 --- /dev/null +++ b/training/results/QM9Dataset/Schnet/Schnet_hyper_U.json @@ -0,0 +1 @@ +{"model": {"class_name": "make_model", "module_name": "kgcnn.literature.Schnet", "config": {"name": "Schnet", "inputs": [{"shape": [null], "name": "node_number", "dtype": "int32"}, {"shape": [null, 3], "name": "node_coordinates", "dtype": "float32"}, {"shape": [null, 2], "name": "range_indices", "dtype": "int64"}, {"shape": [], "name": "total_nodes", "dtype": "int64"}, {"shape": [], "name": "total_ranges", "dtype": "int64"}], "cast_disjoint_kwargs": {"padded_disjoint": false}, "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "last_mlp": {"use_bias": [true, true, true], "units": [128, 64, 1], "activation": ["kgcnn>shifted_softplus", "kgcnn>shifted_softplus", "linear"]}, "interaction_args": {"units": 128, "use_bias": true, "activation": "kgcnn>shifted_softplus", "cfconv_pool": "scatter_sum"}, "node_pooling_args": {"pooling_method": "scatter_sum"}, "depth": 4, "gauss_args": {"bins": 20, "distance": 4, "offset": 0.0, "sigma": 0.4}, "verbose": 10, "output_embedding": "graph", "use_output_mlp": false, "output_mlp": null, "output_scaling": {"name": "ExtensiveMolecularLabelScaler"}}}, "training": {"fit": {"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.0005, "learning_rate_stop": 1e-05, "epo_min": 100, "epo": 800, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.0005}}, "loss": {"class_name": "kgcnn>MeanAbsoluteError", "config": {"dtype": "float64"}}, "metrics": [{"class_name": "MeanAbsoluteError", "config": {"dtype": "float64", "name": "scaled_mean_absolute_error"}}, {"class_name": "RootMeanSquaredError", "config": {"dtype": "float64", "name": "scaled_root_mean_squared_error"}}]}, "scaler": {"class_name": "ExtensiveMolecularLabelScaler", "config": {"atomic_number": "node_number"}}, "multi_target_indices": [11]}, "data": {}, "info": {"postfix": "", "postfix_file": "_U", "kgcnn_version": "4.0.0"}, "dataset": {"class_name": "QM9Dataset", "module_name": "kgcnn.data.datasets.QM9Dataset", "config": {}, "methods": [{"set_train_test_indices_k_fold": {"n_splits": 10, "random_state": 42, "shuffle": true}}, {"map_list": {"method": "set_range", "max_distance": 4, "max_neighbours": 30}}, {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", "count_edges": "range_indices", "count_nodes": "node_number", "total_nodes": "total_nodes"}}]}} \ No newline at end of file diff --git a/training/train_graph.py b/training/train_graph.py index 8b6922fb..684ffff4 100644 --- a/training/train_graph.py +++ b/training/train_graph.py @@ -96,7 +96,7 @@ train_indices_all.append(train_index) # Only do execute_splits out of the k-folds of cross-validation. - if execute_folds: + if execute_folds is not None: if current_split not in execute_folds: continue print("Running training on split: '%s'." % current_split)