diff --git a/training/hyper/hyper_iso17.py b/training/hyper/hyper_iso17.py index 5c4b87a8..4e436744 100644 --- a/training/hyper/hyper_iso17.py +++ b/training/hyper/hyper_iso17.py @@ -53,15 +53,16 @@ "batch_size": 64, "epochs": 1000, "validation_freq": 1, "verbose": 2, "callbacks": [ {"class_name": "kgcnn>LinearWarmupExponentialLRScheduler", "config": { - "lr_start": 1e-03, "gamma": 0.995, "epo_warmup": 1, "verbose": 1, "steps_per_epoch": 8062}} + "lr_start": 1e-03, + "gamma": 0.995, "epo_warmup": 1, "verbose": 1, "steps_per_epoch": 8062}} ] }, "compile": { "optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-03}}, - "loss_weights": {"energy": 1.0, "force": 49.0} + "loss_weights": {"energy": 0.02, "force": 0.98} }, "scaler": {"class_name": "EnergyForceExtensiveLabelScaler", - "config": {"standardize_scale": False}}, + "config": {"standardize_scale": True}}, }, "data": { "dataset": { diff --git a/training/hyper/hyper_md17.py b/training/hyper/hyper_md17.py index e5db4b74..9273a64c 100644 --- a/training/hyper/hyper_md17.py +++ b/training/hyper/hyper_md17.py @@ -61,30 +61,30 @@ }, "compile": { "optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-03}}, - "loss_weights": {"energy": 0.1, "force": 4.9} + "loss_weights": {"energy": 0.02, "force": 0.98} }, "scaler": {"class_name": "EnergyForceExtensiveLabelScaler", "config": {"standardize_scale": True}}, }, - "data": { - "dataset": { - "class_name": "MD17Dataset", - "module_name": "kgcnn.data.datasets.MD17Dataset", - "config": { - "trajectory_name": trajectory_name - }, - "methods": [ - {"rename_property_on_graphs": {"old_property_name": "E", "new_property_name": "energy"}}, - {"rename_property_on_graphs": {"old_property_name": "F", "new_property_name": "force"}}, - {"rename_property_on_graphs": {"old_property_name": "z", "new_property_name": "atomic_number"}}, - {"rename_property_on_graphs": {"old_property_name": "R", "new_property_name": "node_coordinates"}}, - {"map_list": {"method": "set_range", "max_distance": 5, "max_neighbours": 10000, - "node_coordinates": "node_coordinates"}}, - {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", - "count_edges": "range_indices", "count_nodes": "atomic_number", - "total_nodes": "total_nodes"}}, - ] + "dataset": { + "class_name": "MD17Dataset", + "module_name": "kgcnn.data.datasets.MD17Dataset", + "config": { + "trajectory_name": trajectory_name }, + "methods": [ + {"rename_property_on_graphs": {"old_property_name": "E", "new_property_name": "energy"}}, + {"rename_property_on_graphs": {"old_property_name": "F", "new_property_name": "force"}}, + {"rename_property_on_graphs": {"old_property_name": "z", "new_property_name": "atomic_number"}}, + {"rename_property_on_graphs": {"old_property_name": "R", "new_property_name": "node_coordinates"}}, + {"map_list": {"method": "set_range", "max_distance": 5, "max_neighbours": 10000, + "node_coordinates": "node_coordinates"}}, + {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", + "count_edges": "range_indices", "count_nodes": "atomic_number", + "total_nodes": "total_nodes"}}, + ] + }, + "data": { }, "info": { "postfix": "", diff --git a/training/hyper/hyper_md17_revised.py b/training/hyper/hyper_md17_revised.py index 412b4231..3d3acf4f 100644 --- a/training/hyper/hyper_md17_revised.py +++ b/training/hyper/hyper_md17_revised.py @@ -61,33 +61,33 @@ }, "compile": { "optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-03}}, - "loss_weights": {"energy": 1.0, "force": 49.0} + "loss_weights": {"energy": 0.02, "force": 0.98} }, "scaler": {"class_name": "EnergyForceExtensiveLabelScaler", "config": {"standardize_scale": True}}, }, - "data": { - "dataset": { - "class_name": "MD17RevisedDataset", - "module_name": "kgcnn.data.datasets.MD17RevisedDataset", - "config": { - # toluene, aspirin, malonaldehyde, benzene, ethanol - "trajectory_name": trajectory_name - }, - "methods": [ - {"rename_property_on_graphs": {"old_property_name": "energies", "new_property_name": "energy"}}, - {"rename_property_on_graphs": {"old_property_name": "forces", "new_property_name": "force"}}, - {"rename_property_on_graphs": {"old_property_name": "nuclear_charges", - "new_property_name": "atomic_number"}}, - {"rename_property_on_graphs": {"old_property_name": "coords", - "new_property_name": "node_coordinates"}}, - {"map_list": {"method": "set_range", "max_distance": 5, "max_neighbours": 10000, - "node_coordinates": "node_coordinates"}}, - {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", - "count_edges": "range_indices", "count_nodes": "atomic_number", - "total_nodes": "total_nodes"}}, - ] + "dataset": { + "class_name": "MD17RevisedDataset", + "module_name": "kgcnn.data.datasets.MD17RevisedDataset", + "config": { + # toluene, aspirin, malonaldehyde, benzene, ethanol + "trajectory_name": trajectory_name }, + "methods": [ + {"rename_property_on_graphs": {"old_property_name": "energies", "new_property_name": "energy"}}, + {"rename_property_on_graphs": {"old_property_name": "forces", "new_property_name": "force"}}, + {"rename_property_on_graphs": {"old_property_name": "nuclear_charges", + "new_property_name": "atomic_number"}}, + {"rename_property_on_graphs": {"old_property_name": "coords", + "new_property_name": "node_coordinates"}}, + {"map_list": {"method": "set_range", "max_distance": 5, "max_neighbours": 10000, + "node_coordinates": "node_coordinates"}}, + {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", + "count_edges": "range_indices", "count_nodes": "atomic_number", + "total_nodes": "total_nodes"}}, + ] + }, + "data": { }, "info": { "postfix": "", diff --git a/training/train_force.py b/training/train_force.py index d568b808..49f9aa0b 100644 --- a/training/train_force.py +++ b/training/train_force.py @@ -148,10 +148,13 @@ # Compile model with optimizer and loss model.compile(**hyper.compile( - loss={"energy": "mean_absolute_error", "force": ForceMeanAbsoluteError()}, + loss={ + "energy": "mean_absolute_error", + "force": "mean_absolute_error" # ForceMeanAbsoluteError() + }, metrics=scaled_metrics)) - print(model.predict(x_test)) + model.predict(x_test) # Model summary model.summary() print(" Compiled with jit: %s" % model._jit_compile) # noqa