diff --git a/kgcnn/losses/losses.py b/kgcnn/losses/losses.py index cd006763..9171f801 100644 --- a/kgcnn/losses/losses.py +++ b/kgcnn/losses/losses.py @@ -4,6 +4,7 @@ from keras.losses import mean_absolute_error import keras.saving from kgcnn.ops.core import decompose_ragged_tensor +# from kgcnn.ops.scatter import scatter_reduce_mean @ks.saving.register_keras_serializable(package='kgcnn', name='MeanAbsoluteError') @@ -61,25 +62,27 @@ def get_config(self): @ks.saving.register_keras_serializable(package='kgcnn', name='DisjointForceMeanAbsoluteError') class DisjointForceMeanAbsoluteError(Loss): - """This is dummy class. Not working at the moment as intended. + """This is an experimental class for force loss of disjoint output.""" - We need to pass the node ids here somehow. - """ + # Ideally we want to pass batch IDs and energy shape to the loss to do scatter mean. + # However, passing as attribute similar to mask, has not been working yet. def __init__(self, reduction="sum_over_batch_size", name="force_mean_absolute_error", - squeeze_states: bool = True, find_padded_atoms: bool = True, dtype=None): + squeeze_states: bool = True, padded_disjoint: bool = False, dtype=None): super(DisjointForceMeanAbsoluteError, self).__init__(reduction=reduction, name=name, dtype=dtype) self.squeeze_states = squeeze_states - self.find_padded_atoms = find_padded_atoms + self.padded_disjoint = padded_disjoint def call(self, y_true, y_pred): # Shape: ([N], 3, S) - if self.find_padded_atoms: - check_nonzero = ops.logical_not( - ops.all(ops.isclose(y_true, ops.convert_to_tensor(0., dtype=y_true.dtype)), axis=1)) - y_pred = y_pred * ops.cast(ops.expand_dims(check_nonzero, axis=1), dtype=y_pred.dtype) - row_count = ops.sum(ops.cast(check_nonzero, dtype="int32"), axis=0) + + if self.padded_disjoint: + # Mask Shape: ([N, S]) + mask = ops.logical_not(ops.all(ops.isclose(y_true, ops.convert_to_tensor(0., dtype=y_true.dtype)), axis=1)) + y_pred = y_pred * ops.cast(ops.expand_dims(mask, axis=1), dtype=y_pred.dtype) + row_count = ops.sum(ops.cast(mask, dtype="int32"), axis=0) row_count = ops.where(row_count < 1, 1, row_count) # Prevent divide by 0. + # roq count shape ([S]) norm = 1 / ops.cast(row_count, dtype=y_true.dtype) else: norm = 1/ops.shape(y_true)[0] @@ -87,6 +90,7 @@ def call(self, y_true, y_pred): diff = ops.abs(y_true-y_pred) out = ops.mean(diff, axis=1) out = ops.sum(out, axis=0)*norm + if not self.squeeze_states: out = ops.mean(out, axis=-1) diff --git a/kgcnn/models/force.py b/kgcnn/models/force.py index e4f54164..f67b5870 100644 --- a/kgcnn/models/force.py +++ b/kgcnn/models/force.py @@ -1,5 +1,6 @@ import keras as ks import keras.saving +from keras import ops from typing import Union from kgcnn.models.utils import get_model_class from keras.saving import deserialize_keras_object, serialize_keras_object @@ -87,12 +88,10 @@ def __init__(self, # Additional parameters of io and behavior of this class. self.ragged_validate = ragged_validate self.coordinate_input = coordinate_input - # self.output_to_tensor = output_to_tensor self.output_squeeze_states = output_squeeze_states self.is_physical_force = is_physical_force self.nested_model_config = nested_model_config self._force_outputs = outputs - # self.use_batch_jacobian = use_batch_jacobian self.output_as_dict = output_as_dict if isinstance(output_as_dict, bool):