From 15af06e9e2581fa5d458e0a239f0097081756287 Mon Sep 17 00:00:00 2001 From: PatReis Date: Tue, 14 Nov 2023 17:38:35 +0100 Subject: [PATCH] update for keras 3.0 --- kgcnn/layers/casting.py | 8 + kgcnn/literature/PAiNN/__init__.py | 6 +- kgcnn/literature/PAiNN/_layers.py | 6 +- kgcnn/literature/PAiNN/_make.py | 37 ++-- training/hyper/hyper_mp_jdft2d.py | 69 ++++++ .../Schnet_MatProjectJdft2dDataset_score.yaml | 202 +++++++++--------- training/results/README.md | 2 +- training/train_graph.py | 2 +- 8 files changed, 212 insertions(+), 120 deletions(-) diff --git a/kgcnn/layers/casting.py b/kgcnn/layers/casting.py index e372ccc0..0c96f027 100644 --- a/kgcnn/layers/casting.py +++ b/kgcnn/layers/casting.py @@ -550,6 +550,14 @@ def compute_output_shape(self, input_shape): out_gn, out_id_n, out_size_n = out_n[:1], out_n[:1], input_shape[:1] return out_n, out_gn, out_id_n, out_size_n + def compute_output_spec(self, inputs_spec): + """Compute output spec as possible.""" + output_shape = self.compute_output_shape(inputs_spec.shape) + dtype_batch = self.dtype_batch + output_dtypes = [inputs_spec[0].dtype, dtype_batch, dtype_batch, dtype_batch] + output_spec = [ks.KerasTensor(s, dtype=d) for s, d in zip(output_shape, output_dtypes)] + return output_spec + def build(self, input_shape): self.built = True diff --git a/kgcnn/literature/PAiNN/__init__.py b/kgcnn/literature/PAiNN/__init__.py index 547d2889..69e719cb 100644 --- a/kgcnn/literature/PAiNN/__init__.py +++ b/kgcnn/literature/PAiNN/__init__.py @@ -1,10 +1,10 @@ from ._make import make_model, model_default -# from ._make import make_crystal_model, model_crystal_default +from ._make import make_crystal_model, model_crystal_default __all__ = [ "make_model", "model_default", - # "make_crystal_model", - # "model_crystal_default" + "make_crystal_model", + "model_crystal_default" ] diff --git a/kgcnn/literature/PAiNN/_layers.py b/kgcnn/literature/PAiNN/_layers.py index 643ac9c3..881b310f 100644 --- a/kgcnn/literature/PAiNN/_layers.py +++ b/kgcnn/literature/PAiNN/_layers.py @@ -119,7 +119,8 @@ def get_config(self): config_dense = self.lay_dense1.get_config() for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint", "bias_constraint", "kernel_initializer", "bias_initializer", "activation", "use_bias"]: - config.update({x: config_dense[x]}) + if x in config_dense: + config.update({x: config_dense[x]}) return config @@ -220,7 +221,8 @@ def get_config(self): config_dense = self.lay_dense1.get_config() for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint", "bias_constraint", "kernel_initializer", "bias_initializer", "activation", "use_bias"]: - config.update({x: config_dense[x]}) + if x in config_dense: + config.update({x: config_dense[x]}) return config diff --git a/kgcnn/literature/PAiNN/_make.py b/kgcnn/literature/PAiNN/_make.py index 1c7f414f..df40baed 100644 --- a/kgcnn/literature/PAiNN/_make.py +++ b/kgcnn/literature/PAiNN/_make.py @@ -44,7 +44,8 @@ "depth": 3, "verbose": 10, "output_embedding": "graph", - "output_to_tensor": True, + "output_to_tensor": True, "output_tensor_type": "padded", + "output_scaling": None, "output_mlp": {"use_bias": [True, True], "units": [128, 1], "activation": ["swish", "linear"]} } @@ -168,13 +169,17 @@ def set_scale(*args, **kwargs): model_crystal_default = { "name": "PAiNN", "inputs": [ - {"shape": (None,), "name": "node_attributes", "dtype": "float32", "ragged": True}, + {"shape": (None,), "name": "node_attributes", "dtype": "int64", "ragged": True}, {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32", "ragged": True}, {"shape": (None, 2), "name": "edge_indices", "dtype": "int64", "ragged": True}, {'shape': (None, 3), 'name': "edge_image", 'dtype': 'int64', 'ragged': True}, {'shape': (3, 3), 'name': "graph_lattice", 'dtype': 'float32', 'ragged': False} ], - "input_embedding": {"input_dim": 95, "output_dim": 128}, + "input_tensor_type": "padded", + "input_embedding": None, # deprecated + "cast_disjoint_kwargs": {}, + "input_node_embedding": {"input_dim": 95, "output_dim": 128}, + "has_equivariant_input": False, "equiv_initialize_kwargs": {"dim": 3, "method": "zeros"}, "bessel_basis": {"num_radial": 20, "cutoff": 5.0, "envelope_exponent": 5}, "pooling_args": {"pooling_method": "scatter_sum"}, @@ -185,13 +190,18 @@ def set_scale(*args, **kwargs): "depth": 3, "verbose": 10, "output_embedding": "graph", "output_to_tensor": True, + "output_scaling": None, "output_tensor_type": "padded", "output_mlp": {"use_bias": [True, True], "units": [128, 1], "activation": ["swish", "linear"]} } @update_model_kwargs(model_crystal_default) def make_crystal_model(inputs: list = None, + input_tensor_type: str = None, input_embedding: dict = None, + cast_disjoint_kwargs: dict = None, + has_equivariant_input: bool = None, + input_node_embedding: dict = None, equiv_initialize_kwargs: dict = None, bessel_basis: dict = None, depth: int = None, @@ -204,7 +214,9 @@ def make_crystal_model(inputs: list = None, verbose: int = None, output_embedding: str = None, output_to_tensor: bool = None, - output_mlp: dict = None + output_mlp: dict = None, + output_scaling: dict = None, + output_tensor_type: str = None ): r"""Make `PAiNN `_ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.PAiNN.model_crystal_default`. @@ -250,15 +262,16 @@ def make_crystal_model(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - batched_nodes, batched_x, batched_indices, total_nodes, total_edges = model_inputs[:5] - z, edi, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = CastBatchedIndicesToDisjoint( - **cast_disjoint_kwargs)([batched_nodes, batched_indices, total_nodes, total_edges]) - x, _, _, _ = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([batched_x, total_nodes]) + disjoint_inputs = template_cast_list_input(model_inputs, input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + has_nodes=2+int(has_equivariant_input), has_crystal_input=2, + has_edges=False) - if len(model_inputs) > 5: - v, _, _, _ = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([model_inputs[6], total_edges]) - else: + if not has_equivariant_input: + z, x, edi, edge_image, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs v = None + else: + v, z, x, edi, edge_image, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs # Wrapping disjoint model. out = model_disjoint_crystal( @@ -274,7 +287,7 @@ def make_crystal_model(inputs: list = None, scaler = get_scaler(output_scaling["name"])(**output_scaling) if scaler.extensive: # Node information must be numbers, or we need an additional input. - out = scaler([out, batched_nodes]) + out = scaler([out, z, batch_id_node]) else: out = scaler(out) diff --git a/training/hyper/hyper_mp_jdft2d.py b/training/hyper/hyper_mp_jdft2d.py index c9dcf33c..8dc35989 100644 --- a/training/hyper/hyper_mp_jdft2d.py +++ b/training/hyper/hyper_mp_jdft2d.py @@ -73,4 +73,73 @@ "kgcnn_version": "4.0.0" } }, + "PAiNN.make_crystal_model": { + "model": { + "module_name": "kgcnn.literature.PAiNN", + "class_name": "make_crystal_model", + "config": { + "name": "PAiNN", + "inputs": [ + {"shape": [None], "name": "node_number", "dtype": "int64", "ragged": True}, + {"shape": [None, 3], "name": "node_coordinates", "dtype": "float32", "ragged": True}, + {"shape": [None, 2], "name": "range_indices", "dtype": "int64", "ragged": True}, + {'shape': (None, 3), 'name': "range_image", 'dtype': 'int64', 'ragged': True}, + {'shape': (3, 3), 'name': "graph_lattice", 'dtype': 'float32', 'ragged': False} + ], + "input_tensor_type": "ragged", + "cast_disjoint_kwargs": {}, + "input_node_embedding": {"input_dim": 95, "output_dim": 128}, + "equiv_initialize_kwargs": {"dim": 3, "method": "eye"}, + "bessel_basis": {"num_radial": 20, "cutoff": 5.0, "envelope_exponent": 5}, + "pooling_args": {"pooling_method": "scatter_mean"}, + "conv_args": {"units": 128, "cutoff": None, "conv_pool": "scatter_sum"}, + "update_args": {"units": 128}, "depth": 2, "verbose": 10, + "equiv_normalization": False, "node_normalization": False, + "output_embedding": "graph", + "output_mlp": {"use_bias": [True, True], "units": [128, 1], "activation": ["swish", "linear"]} + } + }, + "training": { + "cross_validation": {"class_name": "KFold", + "config": {"n_splits": 5, "random_state": 42, "shuffle": True}}, + "fit": { + "batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2, + "callbacks": [ + {"class_name": "kgcnn>LinearWarmupLinearLearningRateScheduler", "config": { + "learning_rate_start": 1e-04, "learning_rate_stop": 1e-06, "epo_warmup": 25, "epo": 1000, + "verbose": 0} + } + ] + }, + "compile": { + "optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-04}}, # "clipnorm": 100.0, "clipvalue": 100.0} + "loss": "mean_absolute_error" + }, + "scaler": { + "class_name": "StandardLabelScaler", + "module_name": "kgcnn.data.transform.scaler.standard", + "config": {"with_std": True, "with_mean": True, "copy": True} + }, + "multi_target_indices": None + }, + "data": { + "dataset": { + "class_name": "MatProjectJdft2dDataset", + "module_name": "kgcnn.data.datasets.MatProjectJdft2dDataset", + "config": {}, + "methods": [ + {"map_list": {"method": "set_range_periodic", "max_distance": 5.0}}, + {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", + "count_edges": "range_indices", "count_nodes": "node_number", + "total_nodes": "total_nodes"}}, + ] + }, + "data_unit": "meV/atom" + }, + "info": { + "postfix": "", + "postfix_file": "", + "kgcnn_version": "4.0.0" + } + }, } \ No newline at end of file diff --git a/training/results/MatProjectJdft2dDataset/Schnet_make_crystal_model/Schnet_MatProjectJdft2dDataset_score.yaml b/training/results/MatProjectJdft2dDataset/Schnet_make_crystal_model/Schnet_MatProjectJdft2dDataset_score.yaml index 0297dcd8..526aa1bc 100644 --- a/training/results/MatProjectJdft2dDataset/Schnet_make_crystal_model/Schnet_MatProjectJdft2dDataset_score.yaml +++ b/training/results/MatProjectJdft2dDataset/Schnet_make_crystal_model/Schnet_MatProjectJdft2dDataset_score.yaml @@ -1,11 +1,11 @@ OS: nt_win32 -backend: torch -cuda_available: 'True' +backend: tensorflow +cuda_available: 'False' data_unit: meV/atom -date_time: '2023-09-27 16:08:17' -device_id: '[0]' -device_memory: '[{''allocated'': 0.0, ''cached'': 0.1}]' -device_name: '[''NVIDIA GeForce GTX 1060 6GB'']' +date_time: '2023-11-14 15:39:20' +device_id: '[LogicalDevice(name=''/device:CPU:0'', device_type=''CPU'')]' +device_memory: '[]' +device_name: '[{}]' epochs: - 800 - 800 @@ -21,11 +21,11 @@ learning_rate: - 1.06999996205559e-05 - 1.06999996205559e-05 loss: -- 0.006081013474613428 -- 0.007487953174859285 -- 0.005631189793348312 -- 0.006110622547566891 -- 0.005237901117652655 +- 0.0064817434176802635 +- 0.007909221574664116 +- 0.005964066833257675 +- 0.006100769620388746 +- 0.005272249691188335 max_learning_rate: - 0.0005000000237487257 - 0.0005000000237487257 @@ -33,41 +33,41 @@ max_learning_rate: - 0.0005000000237487257 - 0.0005000000237487257 max_loss: -- 0.4087394177913666 -- 0.4102928638458252 -- 0.4046685993671417 -- 0.4130890667438507 -- 0.4167194962501526 +- 0.4114467203617096 +- 0.4189068377017975 +- 0.40715721249580383 +- 0.3852750062942505 +- 0.4248596131801605 max_scaled_mean_absolute_error: -- 59.10055923461914 -- 56.630287170410156 -- 55.20092010498047 -- 53.308082580566406 -- 49.63638687133789 +- 59.73918151855469 +- 58.42951965332031 +- 55.47541809082031 +- 49.77934265136719 +- 50.780094146728516 max_scaled_root_mean_squared_error: -- 147.88436889648438 -- 141.33148193359375 -- 140.4429473876953 -- 131.02452087402344 -- 187.04994201660156 +- 148.1173858642578 +- 142.6780548095703 +- 137.87705993652344 +- 130.23377990722656 +- 121.23198699951172 max_val_loss: -- 0.2501205801963806 -- 0.3599293828010559 -- 0.3686738908290863 -- 0.4995414614677429 -- 0.5789994597434998 +- 0.24786990880966187 +- 0.3615275025367737 +- 0.40639615058898926 +- 0.497827410697937 +- 0.5765356421470642 max_val_scaled_mean_absolute_error: -- 36.26357650756836 -- 50.10081100463867 -- 50.26176452636719 -- 64.30477142333984 -- 69.21404266357422 +- 35.93727111816406 +- 50.37811279296875 +- 55.472347259521484 +- 64.05498504638672 +- 68.9144287109375 max_val_scaled_root_mean_squared_error: -- 85.91874694824219 -- 112.68403625488281 -- 127.67237091064453 -- 153.7728271484375 -- 185.92977905273438 +- 71.01972198486328 +- 112.99500274658203 +- 140.3220672607422 +- 168.19027709960938 +- 186.11448669433594 min_learning_rate: - 1.06999996205559e-05 - 1.06999996205559e-05 @@ -75,80 +75,80 @@ min_learning_rate: - 1.06999996205559e-05 - 1.06999996205559e-05 min_loss: -- 0.006081013474613428 -- 0.007487953174859285 -- 0.005631189793348312 -- 0.006050079129636288 -- 0.005139544140547514 +- 0.0061943670734763145 +- 0.007811109069734812 +- 0.005964066833257675 +- 0.006100769620388746 +- 0.005272249691188335 min_scaled_mean_absolute_error: -- 0.8825702667236328 -- 1.046212077140808 -- 0.7719058990478516 -- 0.7827093005180359 -- 0.6177893280982971 +- 0.904198408126831 +- 1.0941426753997803 +- 0.8168217539787292 +- 0.7751782536506653 +- 0.6322131752967834 min_scaled_root_mean_squared_error: -- 6.864800930023193 -- 7.26033353805542 -- 5.673646450042725 -- 5.397184371948242 -- 3.852328300476074 +- 7.041600704193115 +- 7.389231204986572 +- 5.576786041259766 +- 5.117680072784424 +- 3.77896785736084 min_val_loss: -- 0.1605568826198578 -- 0.2237636148929596 -- 0.3104589283466339 -- 0.3806076645851135 -- 0.4474336802959442 +- 0.1757761389017105 +- 0.2408549189567566 +- 0.3171415627002716 +- 0.38893163204193115 +- 0.4585150182247162 min_val_scaled_mean_absolute_error: -- 23.278240203857422 -- 31.064952850341797 -- 42.32065200805664 -- 48.98826217651367 -- 53.551326751708984 +- 25.48479461669922 +- 33.454715728759766 +- 43.22237014770508 +- 50.03043746948242 +- 54.83570098876953 min_val_scaled_root_mean_squared_error: -- 49.61679458618164 -- 86.47533416748047 -- 115.76853942871094 -- 137.28257751464844 -- 168.3240966796875 +- 55.98909378051758 +- 88.43077850341797 +- 115.5170669555664 +- 137.11032104492188 +- 165.78585815429688 model_class: make_crystal_model model_name: Schnet model_version: '2023-09-07' multi_target_indices: null number_histories: 5 scaled_mean_absolute_error: -- 0.8825702667236328 -- 1.046212077140808 -- 0.7719058990478516 -- 0.7898567318916321 -- 0.6285557150840759 +- 0.9458503127098083 +- 1.108201026916504 +- 0.8168217539787292 +- 0.7751782536506653 +- 0.6322131752967834 scaled_root_mean_squared_error: -- 6.910317420959473 -- 7.314899444580078 -- 5.6926188468933105 -- 5.397184371948242 -- 3.8545916080474854 +- 7.128204822540283 +- 7.412099838256836 +- 5.626010894775391 +- 5.117680072784424 +- 3.7827579975128174 seed: 42 time_list: -- '0:11:38.156039' -- '0:11:56.957036' -- '0:11:53.190291' -- '0:11:48.403422' -- '0:11:51.930701' +- '0:11:53.572924' +- '0:12:32.080251' +- '0:15:06.213439' +- '0:24:16.044001' +- '0:36:55.076988' val_loss: -- 0.1888023018836975 -- 0.3158849775791168 -- 0.3324388265609741 -- 0.49379169940948486 -- 0.46170324087142944 +- 0.19371935725212097 +- 0.26765677332878113 +- 0.330279141664505 +- 0.4501018822193146 +- 0.48370033502578735 val_scaled_mean_absolute_error: -- 27.37338638305664 -- 44.02018356323242 -- 45.268585205078125 -- 63.57335662841797 -- 55.249542236328125 +- 28.086280822753906 +- 37.340938568115234 +- 44.974525451660156 +- 58.00413131713867 +- 57.80003356933594 val_scaled_root_mean_squared_error: -- 58.30507278442383 -- 106.4539794921875 -- 118.05570983886719 -- 152.6748809814453 -- 169.71142578125 +- 57.09575653076172 +- 90.59886169433594 +- 119.77587127685547 +- 141.45004272460938 +- 169.5246124267578 diff --git a/training/results/README.md b/training/results/README.md index b29cc6a5..76377c6f 100644 --- a/training/results/README.md +++ b/training/results/README.md @@ -138,7 +138,7 @@ Materials Project dataset from Matbench with 636 crystal structures and their co | model | kgcnn | epochs | MAE [meV/atom] | RMSE [meV/atom] | |:--------------------------|:--------|---------:|:-------------------------|:--------------------------| -| Schnet.make_crystal_model | 4.0.0 | 800 | **47.0970 ± 12.1636** | **121.0402 ± 38.7995** | +| Schnet.make_crystal_model | 4.0.0 | 800 | **45.2412 ± 11.6395** | **115.6890 ± 39.0929** | #### MatProjectLogGVRHDataset diff --git a/training/train_graph.py b/training/train_graph.py index d78f1bd5..9810ad81 100644 --- a/training/train_graph.py +++ b/training/train_graph.py @@ -23,7 +23,7 @@ parser = argparse.ArgumentParser(description='Train a GNN on a graph regression or classification task.') parser.add_argument("--hyper", required=False, help="Filepath to hyperparameter config file (.py or .json).", default="hyper/hyper_mp_jdft2d.py") -parser.add_argument("--category", required=False, help="Graph model to train.", default="Schnet.make_crystal_model") +parser.add_argument("--category", required=False, help="Graph model to train.", default="PAiNN.make_crystal_model") parser.add_argument("--model", required=False, help="Graph model to train.", default=None) parser.add_argument("--dataset", required=False, help="Name of the dataset.", default=None) parser.add_argument("--make", required=False, help="Name of the class for model.", default=None)