update for keras 3.0
PatReis committed Nov 17, 2023
1 parent e252412 commit 1361483
Showing 8 changed files with 282 additions and 829 deletions.
2 changes: 1 addition & 1 deletion kgcnn/backend/_tensorflow.py
@@ -54,4 +54,4 @@ def decompose_ragged_tensor(x, batch_dtype="int64"):
 
 
 def norm(x, ord='fro', axis=None, keepdims=False):
-    return tf.norm(x, ord=ord, dim=axis, keepdims=keepdims)
+    return tf.norm(x, ord=ord, axis=axis, keepdims=keepdims)
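Note: tf.norm has no `dim` keyword; the reduction dimension is passed as `axis`, which is what this one-line fix restores. A minimal, self-contained check (assumes TensorFlow is installed; not part of the commit):

import tensorflow as tf

x = tf.constant([[3.0, 4.0], [6.0, 8.0]])
# Per-row Euclidean norms; `axis` selects the reduction dimension.
print(tf.norm(x, ord='euclidean', axis=1))  # tf.Tensor([ 5. 10.], shape=(2,), dtype=float32)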
2 changes: 1 addition & 1 deletion kgcnn/layers/casting.py
@@ -407,7 +407,7 @@ def call(self, inputs: list, **kwargs):
         if self.static_output_shape is not None:
             target_shape = (ops.shape(attr_len)[0], self.static_output_shape[0])
         else:
-            target_shape = (ops.shape(attr_len)[0], ops.amax(attr_len))
+            target_shape = (ops.shape(attr_len)[0], ops.cast(ops.amax(attr_len), dtype="int32"))
 
         if not self.padded_disjoint:
             if attr_id is None:
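Note: ops.shape(...)[0] and ops.amax(attr_len) can come back with different integer dtypes depending on the backend, so the maximum length is now cast to int32 before it is used as a shape entry. A small illustrative sketch of the pattern (assumes Keras 3; the attr_len values are made up and this is not kgcnn's actual code path):

import numpy as np
from keras import ops

attr_len = ops.convert_to_tensor(np.array([3, 5, 2], dtype="int64"))  # hypothetical per-graph lengths
# Cast the reduction result so both shape entries share a backend-friendly integer dtype.
target_shape = (ops.shape(attr_len)[0], ops.cast(ops.amax(attr_len), dtype="int32"))
print(target_shape)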
7 changes: 7 additions & 0 deletions kgcnn/literature/GNNExplain/__init__.py
@@ -0,0 +1,7 @@
+from ._model import GNNExplainerOptimizer, GNNInterface, GNNExplainer
+
+__all__ = [
+    "GNNExplainerOptimizer",
+    "GNNInterface",
+    "GNNExplainer"
+]
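The new package __init__ re-exports the explainer classes, so they can be imported directly from the literature module, for example (assuming kgcnn at this commit is installed):

from kgcnn.literature.GNNExplain import GNNExplainer, GNNExplainerOptimizer, GNNInterface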
40 changes: 24 additions & 16 deletions kgcnn/literature/GNNExplain/_model.py
@@ -194,10 +194,13 @@ def explain(self, graph_instance, output_to_explain=None, inspection=False, **kw
         gnnx_optimizer = GNNExplainerOptimizer(
             self.gnn, graph_instance, **self.gnnexplaineroptimizer_options)
         self.gnnx_optimizer = gnnx_optimizer
+
         if output_to_explain is not None:
             gnnx_optimizer.output_to_explain = output_to_explain
+
         gnnx_optimizer.compile(**self.compile_options)
-        gnnx_optimizer.fit(graph_instance, **fit_options)
+
+        gnnx_optimizer.fit(x=graph_instance, y=gnnx_optimizer.output_to_explain, **fit_options)
 
         # Read out information from inspection_callback
         if inspection:
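Note: the fit call now names both arguments and passes gnnx_optimizer.output_to_explain, the output the explainer should reproduce, as the target y for the compiled loss. A hedged, self-contained sketch of that Keras 3 calling pattern with a stand-in model (names and values are illustrative, not kgcnn's):

import numpy as np
import keras as ks

model = ks.Sequential([ks.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
x_single = np.ones((1, 8), dtype="float32")    # stands in for graph_instance
y_target = np.array([[0.7]], dtype="float32")  # stands in for output_to_explain
model.fit(x=x_single, y=y_target, epochs=3, verbose=0)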
@@ -265,21 +268,20 @@ def __init__(self, graph_instance):
         self.node_mask_loss = []
 
     def on_epoch_begin(self, epoch, logs=None):
-        masked = self.model.call(self.graph_instance).numpy()[0]
+        masked = ops.convert_to_numpy(self.model.call(self.graph_instance))[0]
         self.predictions.append(masked)
 
     def on_epoch_end(self, epoch, logs=None):
         """After epoch."""
-        index = 0
-        losses_list = [loss_iter.numpy() for loss_iter in self.model.losses]
         if self.model.edge_mask_loss_weight > 0:
-            self.edge_mask_loss.append(losses_list[index])
-            index = index + 1
+            self.edge_mask_loss.append(ops.convert_to_numpy(self.model._metric_edge_tracker.result()))
+            self.model._metric_edge_tracker.reset_state()
         if self.model.feature_mask_loss_weight > 0:
-            self.feature_mask_loss.append(losses_list[index])
-            index = index + 1
+            self.feature_mask_loss.append(ops.convert_to_numpy(self.model._metric_feature_tracker.result()))
+            self.model._metric_feature_tracker.reset_state()
         if self.model.node_mask_loss_weight > 0:
-            self.node_mask_loss.append(losses_list[index])
+            self.node_mask_loss.append(ops.convert_to_numpy(self.model._metric_node_tracker.result()))
+            self.model._metric_node_tracker.reset_state()
         self.total_loss.append(logs['loss'])
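Note: instead of indexing into model.losses and calling .numpy() on each entry (TensorFlow-specific and no longer reliable under Keras 3), the callback now reads backend-agnostic keras.metrics.Mean trackers that the optimizer updates during call() and that are reset here each epoch. A minimal, self-contained sketch of that tracker pattern (illustrative values only):

import keras as ks
from keras import ops

tracker = ks.metrics.Mean(name="mask_loss")
tracker.update_state([0.5])
tracker.update_state([1.5])
print(float(ops.convert_to_numpy(tracker.result())))  # 1.0, the running mean accumulated so far
tracker.reset_state()                                 # cleared again for the next epoch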


@@ -320,6 +322,9 @@ def __init__(self, gnn_model, graph_instance,
         """
         super(GNNExplainerOptimizer, self).__init__(**kwargs)
         self.gnn_model = gnn_model
+        self._metric_node_tracker = ks.metrics.Mean(name="mask_loss")
+        self._metric_edge_tracker = ks.metrics.Mean(name="mask_loss")
+        self._metric_feature_tracker = ks.metrics.Mean(name="mask_loss")
         self._edge_mask_dim = self.gnn_model.get_number_of_edges(
             graph_instance)
         self._feature_mask_dim = self.gnn_model.get_number_of_node_features(
@@ -368,7 +373,7 @@ def call(self, graph_input, training: bool = False, **kwargs):
             training (bool): If training mode. Default is False.
 
         Returns:
-            tf.tensor: Masked prediction of GNN model.
+            Tensor: Masked prediction of GNN model.
         """
         edge_mask = self.get_mask("edge")
         feature_mask = self.get_mask("feature")
@@ -377,16 +382,19 @@
 
         # edge_mask loss
         if self.edge_mask_loss_weight > 0:
-            self.add_loss(lambda: norm(ops.sigmoid(
-                self.edge_mask), ord=self.edge_mask_norm_ord) * self.edge_mask_loss_weight)
+            loss = norm(ops.sigmoid(self.edge_mask), ord=self.edge_mask_norm_ord) * self.edge_mask_loss_weight
+            self.add_loss(loss)
+            self._metric_edge_tracker.update_state([loss])
         # feature_mask loss
         if self.feature_mask_loss_weight > 0:
-            self.add_loss(lambda: norm(ops.sigmoid(
-                self.feature_mask), ord=self.feature_mask_norm_ord) * self.feature_mask_loss_weight)
+            loss = norm(ops.sigmoid(self.feature_mask), ord=self.feature_mask_norm_ord) * self.feature_mask_loss_weight
+            self.add_loss(loss)
+            self._metric_feature_tracker.update_state([loss])
         # node_mask loss
         if self.node_mask_loss_weight > 0:
-            self.add_loss(lambda: norm(ops.sigmoid(
-                self.node_mask), ord=self.node_mask_norm_ord) * self.node_mask_loss_weight)
+            loss = norm(ops.sigmoid(self.node_mask), ord=self.node_mask_norm_ord) * self.node_mask_loss_weight
+            self.add_loss(loss)
+            self._metric_node_tracker.update_state([loss])
 
         return y_pred
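Note: Keras 3's add_loss expects a tensor rather than the zero-argument lambdas used before, so each mask penalty is now computed eagerly, added as a loss, and also fed to its Mean tracker for the inspection callback. A standalone sketch of one such penalty term written with plain keras.ops (illustrative values; the code above calls kgcnn's own norm helper, which for the TensorFlow backend appears to be the tf.norm wrapper fixed at the top of this commit):

from keras import ops

edge_mask = ops.convert_to_tensor([[0.2], [1.5], [-0.7]])  # illustrative raw mask weights
edge_mask_loss_weight = 0.1
# L2/Frobenius norm of the sigmoid-activated mask, scaled by its loss weight.
loss = ops.sqrt(ops.sum(ops.square(ops.sigmoid(edge_mask)))) * edge_mask_loss_weight
print(float(ops.convert_to_numpy(loss)))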
