Skip to content

Commit

Permalink
update for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Nov 14, 2023
1 parent 3ab54e9 commit 15af06e
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 120 deletions.
8 changes: 8 additions & 0 deletions kgcnn/layers/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,14 @@ def compute_output_shape(self, input_shape):
out_gn, out_id_n, out_size_n = out_n[:1], out_n[:1], input_shape[:1]
return out_n, out_gn, out_id_n, out_size_n

def compute_output_spec(self, inputs_spec):
"""Compute output spec as possible."""
output_shape = self.compute_output_shape(inputs_spec.shape)
dtype_batch = self.dtype_batch
output_dtypes = [inputs_spec[0].dtype, dtype_batch, dtype_batch, dtype_batch]
output_spec = [ks.KerasTensor(s, dtype=d) for s, d in zip(output_shape, output_dtypes)]
return output_spec

def build(self, input_shape):
self.built = True

Expand Down
6 changes: 3 additions & 3 deletions kgcnn/literature/PAiNN/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from ._make import make_model, model_default
# from ._make import make_crystal_model, model_crystal_default
from ._make import make_crystal_model, model_crystal_default


__all__ = [
"make_model",
"model_default",
# "make_crystal_model",
# "model_crystal_default"
"make_crystal_model",
"model_crystal_default"
]
6 changes: 4 additions & 2 deletions kgcnn/literature/PAiNN/_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def get_config(self):
config_dense = self.lay_dense1.get_config()
for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
"bias_constraint", "kernel_initializer", "bias_initializer", "activation", "use_bias"]:
config.update({x: config_dense[x]})
if x in config_dense:
config.update({x: config_dense[x]})
return config


Expand Down Expand Up @@ -220,7 +221,8 @@ def get_config(self):
config_dense = self.lay_dense1.get_config()
for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
"bias_constraint", "kernel_initializer", "bias_initializer", "activation", "use_bias"]:
config.update({x: config_dense[x]})
if x in config_dense:
config.update({x: config_dense[x]})
return config


Expand Down
37 changes: 25 additions & 12 deletions kgcnn/literature/PAiNN/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
"depth": 3,
"verbose": 10,
"output_embedding": "graph",
"output_to_tensor": True,
"output_to_tensor": True, "output_tensor_type": "padded",
"output_scaling": None,
"output_mlp": {"use_bias": [True, True], "units": [128, 1], "activation": ["swish", "linear"]}
}

Expand Down Expand Up @@ -168,13 +169,17 @@ def set_scale(*args, **kwargs):
model_crystal_default = {
"name": "PAiNN",
"inputs": [
{"shape": (None,), "name": "node_attributes", "dtype": "float32", "ragged": True},
{"shape": (None,), "name": "node_attributes", "dtype": "int64", "ragged": True},
{"shape": (None, 3), "name": "node_coordinates", "dtype": "float32", "ragged": True},
{"shape": (None, 2), "name": "edge_indices", "dtype": "int64", "ragged": True},
{'shape': (None, 3), 'name': "edge_image", 'dtype': 'int64', 'ragged': True},
{'shape': (3, 3), 'name': "graph_lattice", 'dtype': 'float32', 'ragged': False}
],
"input_embedding": {"input_dim": 95, "output_dim": 128},
"input_tensor_type": "padded",
"input_embedding": None, # deprecated
"cast_disjoint_kwargs": {},
"input_node_embedding": {"input_dim": 95, "output_dim": 128},
"has_equivariant_input": False,
"equiv_initialize_kwargs": {"dim": 3, "method": "zeros"},
"bessel_basis": {"num_radial": 20, "cutoff": 5.0, "envelope_exponent": 5},
"pooling_args": {"pooling_method": "scatter_sum"},
Expand All @@ -185,13 +190,18 @@ def set_scale(*args, **kwargs):
"depth": 3,
"verbose": 10,
"output_embedding": "graph", "output_to_tensor": True,
"output_scaling": None, "output_tensor_type": "padded",
"output_mlp": {"use_bias": [True, True], "units": [128, 1], "activation": ["swish", "linear"]}
}


@update_model_kwargs(model_crystal_default)
def make_crystal_model(inputs: list = None,
input_tensor_type: str = None,
input_embedding: dict = None,
cast_disjoint_kwargs: dict = None,
has_equivariant_input: bool = None,
input_node_embedding: dict = None,
equiv_initialize_kwargs: dict = None,
bessel_basis: dict = None,
depth: int = None,
Expand All @@ -204,7 +214,9 @@ def make_crystal_model(inputs: list = None,
verbose: int = None,
output_embedding: str = None,
output_to_tensor: bool = None,
output_mlp: dict = None
output_mlp: dict = None,
output_scaling: dict = None,
output_tensor_type: str = None
):
r"""Make `PAiNN <https://arxiv.org/pdf/2102.03150.pdf>`_ graph network via functional API.
Default parameters can be found in :obj:`kgcnn.literature.PAiNN.model_crystal_default`.
Expand Down Expand Up @@ -250,15 +262,16 @@ def make_crystal_model(inputs: list = None,
# Make input
model_inputs = [Input(**x) for x in inputs]

batched_nodes, batched_x, batched_indices, total_nodes, total_edges = model_inputs[:5]
z, edi, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = CastBatchedIndicesToDisjoint(
**cast_disjoint_kwargs)([batched_nodes, batched_indices, total_nodes, total_edges])
x, _, _, _ = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([batched_x, total_nodes])
disjoint_inputs = template_cast_list_input(model_inputs, input_tensor_type=input_tensor_type,
cast_disjoint_kwargs=cast_disjoint_kwargs,
has_nodes=2+int(has_equivariant_input), has_crystal_input=2,
has_edges=False)

if len(model_inputs) > 5:
v, _, _, _ = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([model_inputs[6], total_edges])
else:
if not has_equivariant_input:
z, x, edi, edge_image, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs
v = None
else:
v, z, x, edi, edge_image, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs

# Wrapping disjoint model.
out = model_disjoint_crystal(
Expand All @@ -274,7 +287,7 @@ def make_crystal_model(inputs: list = None,
scaler = get_scaler(output_scaling["name"])(**output_scaling)
if scaler.extensive:
# Node information must be numbers, or we need an additional input.
out = scaler([out, batched_nodes])
out = scaler([out, z, batch_id_node])
else:
out = scaler(out)

Expand Down
69 changes: 69 additions & 0 deletions training/hyper/hyper_mp_jdft2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,73 @@
"kgcnn_version": "4.0.0"
}
},
"PAiNN.make_crystal_model": {
"model": {
"module_name": "kgcnn.literature.PAiNN",
"class_name": "make_crystal_model",
"config": {
"name": "PAiNN",
"inputs": [
{"shape": [None], "name": "node_number", "dtype": "int64", "ragged": True},
{"shape": [None, 3], "name": "node_coordinates", "dtype": "float32", "ragged": True},
{"shape": [None, 2], "name": "range_indices", "dtype": "int64", "ragged": True},
{'shape': (None, 3), 'name': "range_image", 'dtype': 'int64', 'ragged': True},
{'shape': (3, 3), 'name': "graph_lattice", 'dtype': 'float32', 'ragged': False}
],
"input_tensor_type": "ragged",
"cast_disjoint_kwargs": {},
"input_node_embedding": {"input_dim": 95, "output_dim": 128},
"equiv_initialize_kwargs": {"dim": 3, "method": "eye"},
"bessel_basis": {"num_radial": 20, "cutoff": 5.0, "envelope_exponent": 5},
"pooling_args": {"pooling_method": "scatter_mean"},
"conv_args": {"units": 128, "cutoff": None, "conv_pool": "scatter_sum"},
"update_args": {"units": 128}, "depth": 2, "verbose": 10,
"equiv_normalization": False, "node_normalization": False,
"output_embedding": "graph",
"output_mlp": {"use_bias": [True, True], "units": [128, 1], "activation": ["swish", "linear"]}
}
},
"training": {
"cross_validation": {"class_name": "KFold",
"config": {"n_splits": 5, "random_state": 42, "shuffle": True}},
"fit": {
"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2,
"callbacks": [
{"class_name": "kgcnn>LinearWarmupLinearLearningRateScheduler", "config": {
"learning_rate_start": 1e-04, "learning_rate_stop": 1e-06, "epo_warmup": 25, "epo": 1000,
"verbose": 0}
}
]
},
"compile": {
"optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-04}}, # "clipnorm": 100.0, "clipvalue": 100.0}
"loss": "mean_absolute_error"
},
"scaler": {
"class_name": "StandardLabelScaler",
"module_name": "kgcnn.data.transform.scaler.standard",
"config": {"with_std": True, "with_mean": True, "copy": True}
},
"multi_target_indices": None
},
"data": {
"dataset": {
"class_name": "MatProjectJdft2dDataset",
"module_name": "kgcnn.data.datasets.MatProjectJdft2dDataset",
"config": {},
"methods": [
{"map_list": {"method": "set_range_periodic", "max_distance": 5.0}},
{"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges",
"count_edges": "range_indices", "count_nodes": "node_number",
"total_nodes": "total_nodes"}},
]
},
"data_unit": "meV/atom"
},
"info": {
"postfix": "",
"postfix_file": "",
"kgcnn_version": "4.0.0"
}
},
}
Loading

0 comments on commit 15af06e

Please sign in to comment.