-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy path_model.py
436 lines (374 loc) · 20.4 KB
/
_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
import keras as ks
from keras import ops
from kgcnn.ops.core import norm
class GNNInterface:
    """Contract a Graph Neural Network (GNN) wrapper has to fulfil to be explainable.

    The `GNNExplainer` only talks to a model through the methods declared here.
    This class provides no behaviour of its own: every method raises
    `NotImplementedError` and is meant to be overridden in a subclass, which may
    simply wrap an existing Keras GNN.

    `predict` and `masked_predict` are expected to return outputs of the same
    dimension, namely the dimension of the output that should be explained.
    """

    def predict(self, gnn_input, **kwargs):
        """Return the model prediction for `gnn_input`.

        Args:
            gnn_input: The input graph for which the GNN should make a prediction.

        Raises:
            NotImplementedError: Always; override this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")

    def masked_predict(self, gnn_input, edge_mask, feature_mask, node_mask, **kwargs):
        """Return the model prediction for `gnn_input` after applying the three masks.

        Args:
            gnn_input: The input graph that is masked before the GNN makes a prediction.
            edge_mask: A `Tensor` of shape `[get_number_of_edges(self, gnn_input), 1]`
                used to mask the edges of the input graph.
            feature_mask: A `Tensor` of shape `[get_number_of_node_features(self, gnn_input), 1]`
                used to mask the node features of the input graph.
            node_mask: A `Tensor` of shape `[get_number_of_nodes(self, gnn_input), 1]`
                used to mask the nodes of the input graph.

        Raises:
            NotImplementedError: Always; override this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")

    def get_number_of_nodes(self, gnn_input):
        """Return how many nodes the `gnn_input` graph has.

        Args:
            gnn_input: The input graph whose nodes are counted.

        Raises:
            NotImplementedError: Always; override this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")

    def get_number_of_edges(self, gnn_input):
        """Return how many edges the `gnn_input` graph has.

        Args:
            gnn_input: The input graph whose edges are counted.

        Raises:
            NotImplementedError: Always; override this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")

    def get_number_of_node_features(self, gnn_input):
        """Return how many node features the `gnn_input` graph has.

        Args:
            gnn_input: The input graph whose node features are counted.

        Raises:
            NotImplementedError: Always; override this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")

    def get_explanation(self, gnn_input, edge_mask, feature_mask, node_mask, **kwargs):
        """Combine the input graph and the learned masks into an explanation object.

        The form of the explanation is up to the implementation. It could, for
        example, be a networkx graph labelled with the mask values on its
        nodes/edges plus a dict holding the feature explanation values.

        Args:
            gnn_input: The input graph for which the masks were found by the GNNExplainer.
            edge_mask: A `Tensor` of shape `[get_number_of_edges(self, gnn_input), 1]`
                found by the GNNExplainer.
            feature_mask: A `Tensor` of shape `[get_number_of_node_features(self, gnn_input), 1]`
                found by the GNNExplainer.
            node_mask: A `Tensor` of shape `[get_number_of_nodes(self, gnn_input), 1]`
                found by the GNNExplainer.

        Raises:
            NotImplementedError: Always; override this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")

    def present_explanation(self, explanation, **kwargs):
        """Present an explanation produced by `get_explanation` to the user.

        How an explanation is best presented depends on the data domain and the
        targeted user group. Possible presentations include:

            * A visualization of the most relevant subgraph(s) to the decision
            * A visualization of the whole graph with highlighted parts
            * Bar diagrams for feature explanations
            * ...

        Args:
            explanation: An explanation of the GNN decision, in the form returned
                by the `get_explanation` method.

        Raises:
            NotImplementedError: Always; override this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")
class GNNExplainer:
    """`GNNExplainer` explains the decisions of a GNN, which implements `GNNInterface`.

    See Ying et al. (https://arxiv.org/abs/1903.03894) for details on how such an explanation is found.
    Note that this implementation is inspired by the paper by Ying et al., but differs in some aspects.
    """

    def __init__(self, gnn, gnnexplaineroptimizer_options=None,
                 compile_options=None, fit_options=None, **kwargs):
        """Constructs a GNNExplainer instance for the given `gnn`.

        Args:
            gnn: An instance of a class which implements the `GNNInterface`.
            gnnexplaineroptimizer_options (dict, optional): Parameters in this dict are forwarded to the
                constructor of the `GNNExplainerOptimizer` (see docstring of `GNNExplainerOptimizer.__init__`).
                Defaults to None, which is treated as an empty dict.
            compile_options (dict, optional): Parameters in this dict are forwarded to the `compile` method
                of the `GNNExplainerOptimizer`. Can be used to customize the optimization process of the
                `GNNExplainerOptimizer`. The dict is copied (shallow), so the caller's dict is not modified.
                Defaults to None, which is treated as an empty dict.
            fit_options (dict, optional): Parameters in this dict are forwarded to the `fit` method
                of the `GNNExplainerOptimizer`.
                Defaults to None, which is treated as an empty dict.
            **kwargs: Accepted but not used by this constructor.
        """
        if gnnexplaineroptimizer_options is None:
            gnnexplaineroptimizer_options = {}
        # Shallow copy: the "optimizer" entry may be replaced below and that must not leak
        # into the dict object owned by the caller.
        compile_options = {} if compile_options is None else dict(compile_options)
        if fit_options is None:
            fit_options = {}
        self.gnn = gnn
        self.gnnx_optimizer = None
        self.graph_instance = None
        self.gnnexplaineroptimizer_options = gnnexplaineroptimizer_options
        # We need to save options as serialized version to recreate the optimizer on multiple explain calls.
        optimizer = compile_options.get("optimizer")
        if isinstance(optimizer, ks.optimizers.Optimizer):
            compile_options["optimizer"] = ks.saving.serialize_keras_object(optimizer)
        self.compile_options = compile_options
        self.fit_options = fit_options

    def explain(self, graph_instance, output_to_explain=None, inspection=False, **kwargs):
        """Finds the masks to the decision of the `self.gnn` on the given `graph_instance`.

        To get the explanation which was found, call `get_explanation` after calling this method.
        This method instantiates a new `GNNExplainerOptimizer`, compiles it with the stored
        compile options, and fits it on `graph_instance` to optimize the masks.

        Args:
            graph_instance: The graph input to the GNN to which an explanation should be found.
            output_to_explain (optional): Set this parameter to the output which should be explained.
                By default the GNNExplainer explains the output the `self.gnn` on the given `graph_instance`.
                Defaults to None.
            inspection (optional): If `inspection` is set to True this function will return information
                about the optimization process in a dictionary form.
                Be aware that inspections results in longer runtimes.
                Defaults to False.
            **kwargs: Accepted but not used by this method.

        Returns:
            dict or None: If `inspection` is True, a dict which maps the non-empty fields among
            'predictions', 'total_loss', 'edge_mask_loss', 'feature_mask_loss' and 'node_mask_loss'
            to the lists recorded by the inspection callback. Otherwise None.
        """
        self.graph_instance = graph_instance

        # `copy()` is shallow, so a 'callbacks' list would still be shared with `self.fit_options`.
        # Build a fresh list instead of appending in place; otherwise inspection callbacks would
        # pile up in the stored options (and in the caller's list) with every `explain` call.
        fit_options = self.fit_options.copy()
        inspection_callback = None
        if inspection:
            inspection_callback = self.InspectionCallback(self.graph_instance)
            user_callbacks = fit_options.get('callbacks') or []
            fit_options['callbacks'] = [*user_callbacks, inspection_callback]

        # Set up GNNExplainerOptimizer and optimize with respect to masks
        gnnx_optimizer = GNNExplainerOptimizer(
            self.gnn, graph_instance, **self.gnnexplaineroptimizer_options)
        self.gnnx_optimizer = gnnx_optimizer

        if output_to_explain is not None:
            if gnnx_optimizer._output_to_explain_as_variable:
                gnnx_optimizer.output_to_explain.assign(output_to_explain)
            else:
                gnnx_optimizer.output_to_explain = output_to_explain

        gnnx_optimizer.compile(**self.compile_options)

        # Build gnnx_optimizer with example graph instance.
        gnnx_optimizer.predict(graph_instance, steps=1)

        gnnx_optimizer.fit(x=graph_instance, y=gnnx_optimizer.output_to_explain, **fit_options)

        if not inspection:
            return None

        # Read out information from inspection_callback; fields that stayed empty are left out.
        dict_fields = ['predictions',
                       'total_loss',
                       'edge_mask_loss',
                       'feature_mask_loss',
                       'node_mask_loss']
        inspection_information = {}
        for field in dict_fields:
            recorded = getattr(inspection_callback, field, None)
            if recorded is not None and len(recorded) > 0:
                inspection_information[field] = recorded
        return inspection_information

    def get_explanation(self, **kwargs):
        """Returns the explanation (derived from the learned masks) to a decision on the graph,
        which was passed to the `explain` method before.

        Important: The `explain` method should always be called before calling this method.
        Internally this method just calls the `GNNInterface.get_explanation` method
        implemented by the `self.gnn` with the masks found by the `GNNExplainerOptimizer` as parameters.

        Args:
            **kwargs: Forwarded to `self.gnn.get_explanation`.

        Raises:
            RuntimeError: If the `explain` method was not called before.

        Returns:
            The explanation which is returned by `GNNInterface.get_explanation` implemented by the `self.gnn`,
            parametrized by the learned masks.
        """
        if self.graph_instance is None or self.gnnx_optimizer is None:
            raise RuntimeError(
                "You must first call explain on the GNNExplainer instance.")

        edge_mask = self.gnnx_optimizer.get_mask("edge")
        feature_mask = self.gnnx_optimizer.get_mask("feature")
        node_mask = self.gnnx_optimizer.get_mask("node")

        return self.gnn.get_explanation(self.graph_instance,
                                        edge_mask,
                                        feature_mask,
                                        node_mask, **kwargs)

    def present_explanation(self, explanation, **kwargs):
        """Takes an explanation, which was generated by `get_explanation` and presents it.

        Internally this method just calls the `GNNInterface.present_explanation` method
        implemented by the `self.gnn`.

        Args:
            explanation: The explanation (obtained by `get_explanation`) which should be presented.
            **kwargs: Forwarded to `self.gnn.present_explanation`.

        Returns:
            A presentation of the given explanation.
        """
        return self.gnn.present_explanation(explanation, **kwargs)

    class InspectionCallback(ks.callbacks.Callback):
        """Callback class to get the inspection information,
        if 'inspection' is set to true for the 'GNNExplainer.explain' method.

        The recorded values are stored in the list attributes `predictions`, `total_loss`,
        `edge_mask_loss`, `feature_mask_loss` and `node_mask_loss`, one entry per epoch.
        """

        def __init__(self, graph_instance):
            """Creates the callback with empty record lists.

            Args:
                graph_instance: The graph which is fed to the model at the beginning of each epoch
                    to record the masked prediction.
            """
            super().__init__()
            self.graph_instance = graph_instance
            self.predictions = []
            self.total_loss = []
            self.edge_mask_loss = []
            self.feature_mask_loss = []
            self.node_mask_loss = []

        def on_epoch_begin(self, epoch, logs=None):
            """Before epoch: record the first entry of the model output for the stored graph instance."""
            masked = ops.convert_to_numpy(self.model.call(self.graph_instance))[0]
            self.predictions.append(masked)

        def on_epoch_end(self, epoch, logs=None):
            """After epoch: record each enabled mask loss (resetting its tracker) and the total loss."""
            # A mask loss is only tracked by the model when its loss weight is positive.
            if self.model.edge_mask_loss_weight > 0:
                self.edge_mask_loss.append(ops.convert_to_numpy(self.model._metric_edge_tracker.result()))
                self.model._metric_edge_tracker.reset_state()
            if self.model.feature_mask_loss_weight > 0:
                self.feature_mask_loss.append(ops.convert_to_numpy(self.model._metric_feature_tracker.result()))
                self.model._metric_feature_tracker.reset_state()
            if self.model.node_mask_loss_weight > 0:
                self.node_mask_loss.append(ops.convert_to_numpy(self.model._metric_node_tracker.result()))
                self.model._metric_node_tracker.reset_state()
            self.total_loss.append(logs['loss'])
class GNNExplainerOptimizer(ks.Model):
    """The `GNNExplainerOptimizer` solves the optimization problem which is used to find masks,
    which then can be used to explain decisions by GNNs.

    The masks are stored as trainable weights `edge_mask`, `feature_mask` and `node_mask`.
    A sigmoid is applied to them before they are handed to the wrapped GNN.
    """

    # If True, `output_to_explain` is stored as a non-trainable weight; otherwise as a plain tensor.
    _output_to_explain_as_variable = False

    def __init__(self, gnn_model, graph_instance,
                 edge_mask_loss_weight=1e-4,
                 edge_mask_norm_ord=1,
                 feature_mask_loss_weight=1e-4,
                 feature_mask_norm_ord=1,
                 node_mask_loss_weight=0.0,
                 node_mask_norm_ord=1,
                 **kwargs):
        """Constructs a `GNNExplainerOptimizer` instance with the given parameters.

        Args:
            gnn_model (GNNInterface): An instance of a class which implements the methods of the `GNNInterface`.
            graph_instance: The graph to which the masks should be found.
            edge_mask_loss_weight (float, optional): The weight of the edge mask loss term in the optimization
                problem. A value that is not greater than zero disables the edge mask (see `get_mask`).
                Defaults to 1e-4.
            edge_mask_norm_ord (float, optional): The norm p value for the p-norm, which is used on the edge mask.
                Smaller values encourage sparser masks.
                Defaults to 1.
            feature_mask_loss_weight (float, optional): The weight of the feature mask loss term in the
                optimization problem. A value that is not greater than zero disables the feature mask.
                Defaults to 1e-4.
            feature_mask_norm_ord (float, optional): The norm p value for the p-norm, which is used on the
                feature mask. Smaller values encourage sparser masks.
                Defaults to 1.
            node_mask_loss_weight (float, optional): The weight of the node mask loss term in the optimization
                problem. A value that is not greater than zero disables the node mask.
                Defaults to 0.0.
            node_mask_norm_ord (float, optional): The norm p value for the p-norm, which is used on the node mask.
                Smaller values encourage sparser masks.
                Defaults to 1.
            **kwargs: Forwarded to the `keras.Model` constructor.
        """
        super().__init__(**kwargs)
        self.gnn_model = gnn_model
        # NOTE(review): all three trackers share the name "mask_loss"; confirm whether distinct
        # names are wanted before changing them, since the name may show up in training logs.
        self._metric_node_tracker = ks.metrics.Mean(name="mask_loss")
        self._metric_edge_tracker = ks.metrics.Mean(name="mask_loss")
        self._metric_feature_tracker = ks.metrics.Mean(name="mask_loss")

        # The mask sizes are queried from the wrapped GNN for this specific graph instance.
        self._edge_mask_dim = self.gnn_model.get_number_of_edges(
            graph_instance)
        self._feature_mask_dim = self.gnn_model.get_number_of_node_features(
            graph_instance)
        self._node_mask_dim = self.gnn_model.get_number_of_nodes(
            graph_instance)

        # Raw (pre-sigmoid) mask parameters; the constant 5 makes sigmoid(mask) start close to 1.
        self.edge_mask = self.add_weight(
            name='edge_mask',
            shape=(self._edge_mask_dim, 1),
            initializer=ks.initializers.Constant(
                value=5.),
            dtype=self.dtype,
            trainable=True
        )
        self.feature_mask = self.add_weight(
            name='feature_mask',
            shape=(self._feature_mask_dim, 1),
            initializer=ks.initializers.Constant(
                value=5.),
            dtype=self.dtype,
            trainable=True
        )
        self.node_mask = self.add_weight(
            name='node_mask',
            shape=(self._node_mask_dim, 1),
            initializer=ks.initializers.Constant(
                value=5.),
            dtype=self.dtype,
            trainable=True
        )

        # By default the unmasked prediction of the GNN is the output that gets explained.
        output_to_explain = gnn_model.predict(graph_instance)
        if self._output_to_explain_as_variable:
            self.output_to_explain = self.add_weight(
                name='output_to_explain',
                shape=output_to_explain.shape,
                initializer=ks.initializers.Constant(0.),
                dtype=output_to_explain.dtype,
                trainable=False
            )
            self.output_to_explain.assign(output_to_explain)
        else:
            self.output_to_explain = ops.stop_gradient(ops.convert_to_tensor(output_to_explain))

        # Configuration Parameters
        self.edge_mask_loss_weight = edge_mask_loss_weight
        self.edge_mask_norm_ord = edge_mask_norm_ord
        self.feature_mask_loss_weight = feature_mask_loss_weight
        self.feature_mask_norm_ord = feature_mask_norm_ord
        self.node_mask_loss_weight = node_mask_loss_weight
        self.node_mask_norm_ord = node_mask_norm_ord

    def call(self, graph_input, training: bool = False, **kwargs):
        """Call GNN model.

        For every mask with a positive loss weight, the weighted norm of the sigmoid of the mask
        is added as a loss via `add_loss` and recorded in the corresponding metric tracker.

        Args:
            graph_input: Graph input.
            training (bool): If training mode. Default is False.

        Returns:
            Tensor: Masked prediction of GNN model.
        """
        edge_mask = self.get_mask("edge")
        feature_mask = self.get_mask("feature")
        node_mask = self.get_mask("node")
        y_pred = self.gnn_model.masked_predict(graph_input, edge_mask, feature_mask, node_mask, training=training)

        # edge_mask loss
        if self.edge_mask_loss_weight > 0:
            loss = norm(ops.sigmoid(self.edge_mask), ord=self.edge_mask_norm_ord) * self.edge_mask_loss_weight
            self.add_loss(loss)
            self._metric_edge_tracker.update_state([loss])
        # feature_mask loss
        if self.feature_mask_loss_weight > 0:
            loss = norm(ops.sigmoid(self.feature_mask), ord=self.feature_mask_norm_ord) * self.feature_mask_loss_weight
            self.add_loss(loss)
            self._metric_feature_tracker.update_state([loss])
        # node_mask loss
        if self.node_mask_loss_weight > 0:
            loss = norm(ops.sigmoid(self.node_mask), ord=self.node_mask_norm_ord) * self.node_mask_loss_weight
            self.add_loss(loss)
            self._metric_node_tracker.update_state([loss])

        return y_pred

    def get_mask(self, mask_identifier):
        """Returns the mask selected by `mask_identifier` in the form that is applied to the graph.

        Args:
            mask_identifier (str): One of 'edge', 'feature' or 'node'.

        Returns:
            Tensor: The sigmoid of the selected mask weight if its loss weight is greater than zero,
            otherwise a tensor of ones with the shape of the mask.

        Raises:
            ValueError: If `mask_identifier` is not one of the three supported identifiers.
        """
        if mask_identifier == "edge":
            return self._get_mask(self.edge_mask, self.edge_mask_loss_weight)
        elif mask_identifier == "feature":
            return self._get_mask(self.feature_mask, self.feature_mask_loss_weight)
        elif mask_identifier == "node":
            return self._get_mask(self.node_mask, self.node_mask_loss_weight)
        # ValueError is a subclass of Exception, so handlers written for the old bare
        # Exception keep working.
        raise ValueError("mask_identifier must be 'edge', 'feature' or 'node'")

    def _get_mask(self, mask, weight):
        """Returns `sigmoid(mask)` if `weight` is positive, else a tensor of ones shaped like `mask`."""
        if weight > 0:
            return ops.sigmoid(mask)
        return ops.ones_like(mask)