Skip to content

Commit

Permalink
refactoring for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Oct 20, 2023
1 parent dc4f017 commit 9cc8190
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 46 deletions.
7 changes: 4 additions & 3 deletions training/hyper/hyper_iso17.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,16 @@
"batch_size": 64, "epochs": 1000, "validation_freq": 1, "verbose": 2,
"callbacks": [
{"class_name": "kgcnn>LinearWarmupExponentialLRScheduler", "config": {
"lr_start": 1e-03, "gamma": 0.995, "epo_warmup": 1, "verbose": 1, "steps_per_epoch": 8062}}
"lr_start": 1e-03,
"gamma": 0.995, "epo_warmup": 1, "verbose": 1, "steps_per_epoch": 8062}}
]
},
"compile": {
"optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-03}},
"loss_weights": {"energy": 1.0, "force": 49.0}
"loss_weights": {"energy": 0.02, "force": 0.98}
},
"scaler": {"class_name": "EnergyForceExtensiveLabelScaler",
"config": {"standardize_scale": False}},
"config": {"standardize_scale": True}},
},
"data": {
"dataset": {
Expand Down
38 changes: 19 additions & 19 deletions training/hyper/hyper_md17.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,30 +61,30 @@
},
"compile": {
"optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-03}},
"loss_weights": {"energy": 0.1, "force": 4.9}
"loss_weights": {"energy": 0.02, "force": 0.98}
},
"scaler": {"class_name": "EnergyForceExtensiveLabelScaler",
"config": {"standardize_scale": True}},
},
"data": {
"dataset": {
"class_name": "MD17Dataset",
"module_name": "kgcnn.data.datasets.MD17Dataset",
"config": {
"trajectory_name": trajectory_name
},
"methods": [
{"rename_property_on_graphs": {"old_property_name": "E", "new_property_name": "energy"}},
{"rename_property_on_graphs": {"old_property_name": "F", "new_property_name": "force"}},
{"rename_property_on_graphs": {"old_property_name": "z", "new_property_name": "atomic_number"}},
{"rename_property_on_graphs": {"old_property_name": "R", "new_property_name": "node_coordinates"}},
{"map_list": {"method": "set_range", "max_distance": 5, "max_neighbours": 10000,
"node_coordinates": "node_coordinates"}},
{"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges",
"count_edges": "range_indices", "count_nodes": "atomic_number",
"total_nodes": "total_nodes"}},
]
"dataset": {
"class_name": "MD17Dataset",
"module_name": "kgcnn.data.datasets.MD17Dataset",
"config": {
"trajectory_name": trajectory_name
},
"methods": [
{"rename_property_on_graphs": {"old_property_name": "E", "new_property_name": "energy"}},
{"rename_property_on_graphs": {"old_property_name": "F", "new_property_name": "force"}},
{"rename_property_on_graphs": {"old_property_name": "z", "new_property_name": "atomic_number"}},
{"rename_property_on_graphs": {"old_property_name": "R", "new_property_name": "node_coordinates"}},
{"map_list": {"method": "set_range", "max_distance": 5, "max_neighbours": 10000,
"node_coordinates": "node_coordinates"}},
{"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges",
"count_edges": "range_indices", "count_nodes": "atomic_number",
"total_nodes": "total_nodes"}},
]
},
"data": {
},
"info": {
"postfix": "",
Expand Down
44 changes: 22 additions & 22 deletions training/hyper/hyper_md17_revised.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,33 +61,33 @@
},
"compile": {
"optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-03}},
"loss_weights": {"energy": 1.0, "force": 49.0}
"loss_weights": {"energy": 0.02, "force": 0.98}
},
"scaler": {"class_name": "EnergyForceExtensiveLabelScaler",
"config": {"standardize_scale": True}},
},
"data": {
"dataset": {
"class_name": "MD17RevisedDataset",
"module_name": "kgcnn.data.datasets.MD17RevisedDataset",
"config": {
# toluene, aspirin, malonaldehyde, benzene, ethanol
"trajectory_name": trajectory_name
},
"methods": [
{"rename_property_on_graphs": {"old_property_name": "energies", "new_property_name": "energy"}},
{"rename_property_on_graphs": {"old_property_name": "forces", "new_property_name": "force"}},
{"rename_property_on_graphs": {"old_property_name": "nuclear_charges",
"new_property_name": "atomic_number"}},
{"rename_property_on_graphs": {"old_property_name": "coords",
"new_property_name": "node_coordinates"}},
{"map_list": {"method": "set_range", "max_distance": 5, "max_neighbours": 10000,
"node_coordinates": "node_coordinates"}},
{"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges",
"count_edges": "range_indices", "count_nodes": "atomic_number",
"total_nodes": "total_nodes"}},
]
"dataset": {
"class_name": "MD17RevisedDataset",
"module_name": "kgcnn.data.datasets.MD17RevisedDataset",
"config": {
# toluene, aspirin, malonaldehyde, benzene, ethanol
"trajectory_name": trajectory_name
},
"methods": [
{"rename_property_on_graphs": {"old_property_name": "energies", "new_property_name": "energy"}},
{"rename_property_on_graphs": {"old_property_name": "forces", "new_property_name": "force"}},
{"rename_property_on_graphs": {"old_property_name": "nuclear_charges",
"new_property_name": "atomic_number"}},
{"rename_property_on_graphs": {"old_property_name": "coords",
"new_property_name": "node_coordinates"}},
{"map_list": {"method": "set_range", "max_distance": 5, "max_neighbours": 10000,
"node_coordinates": "node_coordinates"}},
{"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges",
"count_edges": "range_indices", "count_nodes": "atomic_number",
"total_nodes": "total_nodes"}},
]
},
"data": {
},
"info": {
"postfix": "",
Expand Down
7 changes: 5 additions & 2 deletions training/train_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,13 @@

# Compile model with optimizer and loss
model.compile(**hyper.compile(
loss={"energy": "mean_absolute_error", "force": ForceMeanAbsoluteError()},
loss={
"energy": "mean_absolute_error",
"force": "mean_absolute_error" # ForceMeanAbsoluteError()
},
metrics=scaled_metrics))

print(model.predict(x_test))
model.predict(x_test)
# Model summary
model.summary()
print(" Compiled with jit: %s" % model._jit_compile) # noqa
Expand Down

0 comments on commit 9cc8190

Please sign in to comment.