Update force model and docs.
PatReis committed Jan 3, 2024
1 parent 18fa413 commit 7000320
Showing 1 changed file with 111 additions and 48 deletions.
159 changes: 111 additions & 48 deletions kgcnn/models/force.py
@@ -6,8 +6,29 @@
from keras.backend import backend


# In keras 3.0.0 there is no `ops.gradient()` function yet.
# Backend-specific gradient implementations are therefore provided in the following.
if backend() == "tensorflow":
import tensorflow as tf
elif backend() == "torch":
import torch
else:
raise NotImplementedError("Backend '%s' not supported for force model." % backend())


@ks.saving.register_keras_serializable(package='kgcnn', name='EnergyForceModel')
class EnergyForceModel(ks.models.Model):
r"""Outer model to wrap a normal invariant GNN to predict forces from energy predictions via partial derivatives.
The Force :math:`\vec{F_i}` on Atom :math:`i` is given by
.. math::
\vec{F_i} = - \vec{\nabla}_i E_{\text{total}}
Note that the model simply returns the tensor type of the coordinate input for forces. No casting is done
by this class. This means that the model returns a ragged, disjoint or padded tensor depending on the tensor
type of the coordinates.
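
    A minimal usage sketch is given below; ``energy_model`` and the input tensors are placeholders for
    an arbitrary keras energy model and its graph tensors, not part of this module:

    .. code-block:: python

        # Wrap an existing keras model that maps graph inputs to energies of shape `(batch, states)`.
        force_model = EnergyForceModel(
            model_energy=energy_model,           # placeholder energy model
            coordinate_input=1,                  # position of the coordinate tensor in the inputs
            output_as_dict=("energy", "force")   # return a dict with these keys
        )
        out = force_model([atomic_numbers, coordinates, edge_indices])  # placeholder inputs
        energy, force = out["energy"], out["force"]  # force has the tensor type of `coordinates`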
"""

def __init__(self,
inputs: Union[dict, list] = None,
@@ -19,11 +40,29 @@ def __init__(self,
output_squeeze_states: bool = False,
nested_model_config: bool = True,
is_physical_force: bool = True,
                 use_batch_jacobian: bool = None,
name: str = None,
outputs: Union[dict, list] = None
):
"""Initialize Force model with an energy model.
Args:
inputs (list): List of inputs as dictionary kwargs of keras input layers.
model_energy (ks.models.Model, dict): Keras model os deserialization dictionary for a keras model.
coordinate_input (int): Position of the coordinate input.
            output_as_dict (bool, tuple): Whether to return the output as a dictionary. Can also be a tuple of
                names for the dictionary keys. A plain ``True`` uses the names "energy" and "force".
ragged_validate (bool): Whether to validate ragged or jagged tensors.
output_to_tensor: Deprecated.
output_squeeze_states (bool): Whether to squeeze state/energy dimension for forces
in case of a single energy value.
            nested_model_config (bool): Whether a nested config is available for the energy model, e.g. if it
                was passed as a deserialization dictionary.
            is_physical_force (bool): Whether to return the physical force, i.e. the negative gradient of the energy.
use_batch_jacobian: Deprecated.
name (str): Name of the model.
outputs: List of outputs as dictionary kwargs similar to inputs. Not used by the model but can be passed
for external use.
"""
super().__init__()
if model_energy is None:
raise ValueError("Require valid model in `model_energy` for force prediction.")
@@ -40,69 +79,92 @@ def __init__(self,
self.energy_model_class = get_model_class(model_energy["module_name"], model_energy["class_name"])
self.energy_model = self.energy_model_class(**model_energy["config"])
else:
raise TypeError("Input `model_energy` must be dict or `ks.models.Model` .")
raise TypeError("Input `model_energy` must be dict or `ks.models.Model` . Can not deserialize model.")

# Additional parameters of io and behavior of this class.
self.ragged_validate = ragged_validate
self.coordinate_input = coordinate_input
        # self.output_to_tensor = output_to_tensor
        self.output_squeeze_states = output_squeeze_states
        self.is_physical_force = is_physical_force
        self.nested_model_config = nested_model_config
        self._force_outputs = outputs
        # self.use_batch_jacobian = use_batch_jacobian

self.output_as_dict = output_as_dict
if isinstance(output_as_dict, bool):
self.output_as_dict_use = output_as_dict
self.output_as_dict_names = ("energy", "force")
elif isinstance(output_as_dict, (list, tuple)):
self.output_as_dict_use = True
self.output_as_dict_names = (output_as_dict[0], output_as_dict[1])
else:
self.output_as_dict_use = False

# We can try to infer the model inputs from energy model, if not given explicit.
self._inputs_to_force_model = inputs
if self._inputs_to_force_model is None:
if self.nested_model_config and isinstance(model_energy, dict):
self._inputs_to_force_model = model_energy["config"]["inputs"]
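
        # Select the backend-specific gradient implementation once at construction time.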

if backend() == "tensorflow":
self._call_grad_backend = self._call_grad_tf
elif backend() == "torch":
self._call_grad_backend = self._call_grad_torch
else:
raise NotImplementedError("Backend '%s' not supported for force model." % backend())

def build(self, input_shape):
self.energy_model.build(input_shape)
self.built = True

    def _call_grad_tf(self, inputs, training=False, **kwargs):

x_in = inputs[self.coordinate_input]
with tf.GradientTape(persistent=True) as tape:
if isinstance(x_in, tf.RaggedTensor):
x, splits = x_in.values, x_in.row_splits
else:
x, splits = x_in, None
tape.watch(x)
eng = self.energy_model(inputs, training=training, **kwargs)
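            # `eng` is treated as having shape `(batch, states)`; the gradient of each summed energy
            # state is taken separately and the results are stacked along a new last axis.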
eng_sum = tf.reduce_sum(eng, axis=0, keepdims=False)
e_grad = [eng_sum[i] for i in range(eng_sum.shape[-1])]
e_grad = [tf.expand_dims(tape.gradient(e_i, x), axis=-1) for e_i in e_grad]
e_grad = tf.concat(e_grad, axis=-1)

if self.output_squeeze_states:
e_grad = tf.squeeze(e_grad, axis=-1)

if isinstance(x_in, tf.RaggedTensor):
e_grad = tf.RaggedTensor.from_row_splits(e_grad, splits, validate=self.ragged_validate)

return eng, e_grad

def _call_grad_torch(self, inputs, training=False, **kwargs):

x = inputs[self.coordinate_input]
with torch.enable_grad():
x.requires_grad = True
eng = self.energy_model.call(inputs, training=training, **kwargs)
eng_sum = eng.sum(dim=0)
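            # Differentiate each summed energy state w.r.t. the coordinates; `create_graph=True` keeps
            # the computation graph so that training through the forces (second derivatives) is possible.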
e_grad = torch.cat([
torch.unsqueeze(torch.autograd.grad(eng_sum[i], x, create_graph=True)[0], dim=-1) for i in
range(eng.shape[-1])], dim=-1)

if backend() == "tensorflow":
import tensorflow as tf
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
eng = self.energy_model(inputs, training=training, **kwargs)
eng_sum = tf.reduce_sum(eng, axis=0, keepdims=False)
e_grad = [eng_sum[i] for i in range(eng_sum.shape[-1])]
e_grad = [tf.expand_dims(tape.gradient(e_i, x), axis=-1) for e_i in e_grad]
e_grad = tf.concat(e_grad, axis=-1)
if self.is_physical_force:
e_grad = -e_grad
if self.output_squeeze_states:
e_grad = tf.squeeze(e_grad, axis=-1)
elif backend() == "torch":
import torch
with torch.enable_grad():
x.requires_grad = True
eng = self.energy_model(inputs, training=training, **kwargs)
eng_sum = eng.sum(dim=0)
e_grad = torch.cat([
torch.unsqueeze(torch.autograd.grad(eng_sum[i], x, create_graph=True)[0], dim=-1) for i in
range(eng.shape[-1])], dim=-1)
if self.is_physical_force:
e_grad = -e_grad
if self.output_squeeze_states:
e_grad = torch.squeeze(e_grad, dim=-1)
elif backend() == "jax":
raise NotImplementedError()
from jax import grad
import jax.numpy as jnp
e_grad = grad(self.energy_model, argnums=self.coordinate_input)(inputs)
if self.is_physical_force:
e_grad = -e_grad
if self.output_squeeze_states:
e_grad = jnp.squeeze(e_grad, axis=-1)
else:
raise NotImplementedError("Gradient not supported for backend '%s'." % backend())
if self.output_squeeze_states:
e_grad = torch.squeeze(e_grad, dim=-1)
return eng, e_grad

def call(self, inputs, training=False, **kwargs):

eng, e_grad = self._call_grad_backend(inputs, training=training, **kwargs)

if self.is_physical_force:
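            # The physical force is the negative gradient of the total energy w.r.t. the atomic coordinates.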
e_grad = -e_grad

        if self.output_as_dict_use:
            return {self.output_as_dict_names[0]: eng, self.output_as_dict_names[1]: e_grad}
else:
return eng, e_grad

@@ -121,10 +183,11 @@ def get_config(self):
"coordinate_input": self.coordinate_input,
"output_as_dict": self.output_as_dict,
"ragged_validate": self.ragged_validate,
"output_to_tensor": self.output_to_tensor,
# "output_to_tensor": self.output_to_tensor,
"output_squeeze_states": self.output_squeeze_states,
"nested_model_config": self.nested_model_config,
"use_batch_jacobian": self.use_batch_jacobian,
"inputs": self._inputs_to_force_model
# "use_batch_jacobian": self.use_batch_jacobian,
"inputs": self._inputs_to_force_model,
"outputs": self._force_outputs
})
return conf
