Skip to content

Commit

Permalink
update for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Nov 21, 2023
1 parent 7fd9451 commit 62114b6
Show file tree
Hide file tree
Showing 6 changed files with 865 additions and 1,069 deletions.
311 changes: 275 additions & 36 deletions docs/source/literature.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion kgcnn/layers/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def build(self, input_shape):
"""Build layer."""
assert len(input_shape) == 4
ref_shape, attr_shape, weights_shape, index_shape = [list(x) for x in input_shape]
self.to_aggregate.build([tuple(x) for x in [attr_shape, index_shape, ref_shape]])
self._to_aggregate.build([tuple(x) for x in [attr_shape, index_shape, ref_shape]])
self.built = True

def call(self, inputs, **kwargs):
Expand Down
24 changes: 22 additions & 2 deletions kgcnn/literature/GNNExplain/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ def __init__(self, gnn, gnnexplaineroptimizer_options=None,
self.gnnx_optimizer = None
self.graph_instance = None
self.gnnexplaineroptimizer_options = gnnexplaineroptimizer_options
# We need to save options as serialized version to recreate the optimizer on multiple explain calls.
if "optimizer" in compile_options:
if isinstance(compile_options["optimizer"], ks.optimizers.Optimizer):
compile_options["optimizer"] = ks.saving.serialize_keras_object(compile_options["optimizer"])
self.compile_options = compile_options
self.fit_options = fit_options

Expand Down Expand Up @@ -196,7 +200,10 @@ def explain(self, graph_instance, output_to_explain=None, inspection=False, **kw
self.gnnx_optimizer = gnnx_optimizer

if output_to_explain is not None:
gnnx_optimizer.output_to_explain = output_to_explain
if gnnx_optimizer._output_to_explain_as_variable:
gnnx_optimizer.output_to_explain.assign(output_to_explain)
else:
gnnx_optimizer.output_to_explain = output_to_explain

gnnx_optimizer.compile(**self.compile_options)

Expand Down Expand Up @@ -290,6 +297,8 @@ class GNNExplainerOptimizer(ks.Model):
which then can be used to explain decisions by GNNs.
"""

_output_to_explain_as_variable = False

def __init__(self, gnn_model, graph_instance,
edge_mask_loss_weight=1e-4,
edge_mask_norm_ord=1,
Expand Down Expand Up @@ -355,7 +364,18 @@ def __init__(self, gnn_model, graph_instance,
dtype=self.dtype,
trainable=True
)
self.output_to_explain = gnn_model.predict(graph_instance)
output_to_explain = gnn_model.predict(graph_instance)
if self._output_to_explain_as_variable:
self.output_to_explain = self.add_weight(
name='output_to_explain',
shape=output_to_explain.shape,
initializer=ks.initializers.Constant(0.),
dtype=output_to_explain.dtype,
trainable=False
)
self.output_to_explain.assign(output_to_explain)
else:
self.output_to_explain = output_to_explain

# Configuration Parameters
self.edge_mask_loss_weight = edge_mask_loss_weight
Expand Down

Large diffs are not rendered by default.

Loading

0 comments on commit 62114b6

Please sign in to comment.