-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate_explainers.py
executable file
·273 lines (228 loc) · 9.1 KB
/
evaluate_explainers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
#!/usr/bin/env python
"""
evaluate_explainers.py - A PostHocExplainerEvaluation file
Copyright (C) 2020 Zach Carmichael
"""
import os
import sys
# ssshhhhhhhhhhhh
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"
# # lazy load tensorflow
# from posthoceval.lazy_loader import LazyLoader
#
# _ = LazyLoader('tensorflow')
from glob import glob
import traceback
from tqdm.auto import tqdm
from joblib import Parallel
from joblib import delayed
import numpy as np
import sympy as sp
from posthoceval.models.synthetic import SyntheticModel
from posthoceval.utils import assert_same_size
from posthoceval.expl_utils import save_explanation
from posthoceval.expl_utils import CompatUnpickler
from posthoceval.explainers.local.shap import KernelSHAPExplainer
from posthoceval.explainers.local.shapr import SHAPRExplainer
from posthoceval.explainers.local.lime import LIMEExplainer
from posthoceval.explainers.local.maple import MAPLEExplainer
from posthoceval.explainers.global_.pdp import PDPExplainer
from posthoceval.explainers import GradCAMExplainer
from posthoceval.explainers import VanillaGradientsExplainer
from posthoceval.explainers import GradientsXInputsExplainer
from posthoceval.explainers import IntegratedGradientsExplainer
from posthoceval.explainers import OcclusionExplainer
from posthoceval.explainers import XRAIExplainer
from posthoceval.explainers import BlurIntegratedGradientsExplainer
from posthoceval.utils import tqdm_parallel
from posthoceval.results import ExprResult
# Registry of selectable explainers, keyed by the name accepted by
# ``--explainer``. ``run`` looks the name up here and ``explain`` calls the
# value as ``value(model, seed=seed)``. Each ``*-Smooth`` entry maps to the
# class's ``smooth_grad`` attribute rather than to the class itself.
# Insertion order is preserved and is the order argparse lists the choices.
EXPLAINER_MAP = dict([
    # explainers imported from posthoceval.explainers.local / .global_
    ('SHAP', KernelSHAPExplainer),
    # TODO: SHAPR for each of the conditioned distributions other than
    #  empirical
    ('SHAPR', SHAPRExplainer),
    ('LIME', LIMEExplainer),
    ('MAPLE', MAPLEExplainer),
    ('PDP', PDPExplainer),
    # explainers imported from posthoceval.explainers, each followed by its
    # ``smooth_grad`` counterpart where one is registered
    ('GradCAM', GradCAMExplainer),
    ('GradCAM-Smooth', GradCAMExplainer.smooth_grad),
    ('VanillaGradients', VanillaGradientsExplainer),
    ('VanillaGradients-Smooth', VanillaGradientsExplainer.smooth_grad),
    ('GradientsXInputs', GradientsXInputsExplainer),
    ('GradientsXInputs-Smooth', GradientsXInputsExplainer.smooth_grad),
    ('IntegratedGradients', IntegratedGradientsExplainer),
    ('IntegratedGradients-Smooth', IntegratedGradientsExplainer.smooth_grad),
    ('Occlusion', OcclusionExplainer),
    ('XRAI', XRAIExplainer),
    ('XRAI-Smooth', XRAIExplainer.smooth_grad),
    ('BlurIG', BlurIntegratedGradientsExplainer),
    ('BlurIG-Smooth', BlurIntegratedGradientsExplainer.smooth_grad),
])
def explain(explainer_cls, out_filename, expr_result: 'ExprResult', data_file,
            max_explain, seed):
    """Fit an explainer to one synthetic model and save its explanation.

    Does nothing if ``out_filename`` already exists, so an interrupted batch
    can be resumed without redoing finished models.

    Args:
        explainer_cls: Callable that builds the explainer as
            ``explainer_cls(model, seed=seed)`` (a value of ``EXPLAINER_MAP``).
        out_filename: Destination path handed to ``save_explanation``.
        expr_result: Result object whose ``expr`` and ``symbols`` attributes
            define the ``SyntheticModel`` to explain.
        data_file: ``.npz`` archive whose ``'data'`` array is used both to fit
            the explainer and as the points to explain.
        max_explain: If not ``None``, explain only the first ``max_explain``
            points (the explainer is still fit on the full array).
        seed: Seed forwarded to the explainer constructor.

    A ``ValueError`` or ``TypeError`` raised while fitting or explaining is
    reported on stderr (the pretty-printed expression plus the traceback) and
    the model is skipped without writing output. Any other exception
    propagates to the caller.
    """
    if os.path.exists(out_filename):
        tqdm.write(f'{out_filename} exists, skipping!')
        return

    tqdm.write('Generating model')
    model = SyntheticModel.from_expr(
        expr=expr_result.expr,
        symbols=expr_result.symbols,
    )

    tqdm.write('Creating explainer')
    explainer = explainer_cls(model, seed=seed)

    tqdm.write(f'Loading data from {data_file}')
    # np.load on an .npz returns an NpzFile that keeps the archive open until
    # closed; read the array, then let the context manager release the handle.
    with np.load(data_file) as npz:
        data = npz['data']
    # Slicing past the end is harmless, so no length check is needed.
    to_explain = data if max_explain is None else data[:max_explain]

    try:
        tqdm.write('Fitting explainer')
        explainer.fit(data)
        tqdm.write('Explaining')
        explanation = explainer.feature_contributions(to_explain)
    except (ValueError, TypeError):
        # Report and skip this model so that one problematic expression does
        # not abort the whole batch; keep all diagnostics on the same stream.
        tqdm.write('Failed to explain model:', file=sys.stderr)
        tqdm.write(sp.pretty(expr_result.expr), file=sys.stderr)
        tqdm.write(traceback.format_exc(), file=sys.stderr, end='')
        return

    tqdm.write('Saving explanation')
    save_explanation(out_filename, explanation)
def run(expr_filename, out_dir, data_dir, max_explain, seed, n_jobs,
        start_at=1, step_size=1, explainer='SHAP', debug=False):
    """Explain a (strided) selection of the models in an expression pickle.

    Loads the pickled expression results from ``expr_filename`` and pairs each
    one with the ``<id>.npz`` file in ``data_dir`` that has the same index.
    The file ids must be exactly ``0 .. len(results) - 1``. For each selected
    pair it calls :func:`explain`, writing to
    ``<out_dir>/<expr basename>/<explainer>/<id>.npz``.

    Args:
        expr_filename: Path of the expression pickle.
        out_dir: Root directory for saved explanations.
        data_dir: Directory containing one ``<id>.npz`` file per expression.
        max_explain: Maximum number of points explained per model, or
            ``None`` for all of them.
        seed: Seed forwarded to every explainer.
        n_jobs: joblib ``n_jobs``. ``1`` runs serially in this process.
        start_at: 1-indexed position of the first expression to explain.
        step_size: Stride between explained expressions.
        explainer: Key of ``EXPLAINER_MAP`` selecting the explainer.
        debug: If true, explain only the first selected model, serially.

    Raises:
        ValueError: If ``explainer`` is not a key of ``EXPLAINER_MAP``.
        AssertionError: If the data file ids are duplicated or do not cover
            ``0 .. len(results) - 1``.
    """
    try:
        explainer_cls = EXPLAINER_MAP[explainer]
    except KeyError:
        raise ValueError(
            f'{explainer} is not a valid explainer name') from None

    basename_experiment = os.path.basename(expr_filename).rsplit('.', 1)[0]
    explainer_out_dir = os.path.join(out_dir, basename_experiment, explainer)
    os.makedirs(explainer_out_dir, exist_ok=True)

    print('Loading', expr_filename, '(this may take a while)')
    # NOTE(review): unpickling can execute arbitrary code -- only point this
    # at expression files you generated or otherwise trust.
    with open(expr_filename, 'rb') as f:
        expr_data = CompatUnpickler(f).load()

    print('Loading data')
    data_files = glob(os.path.join(data_dir, '*.npz'))
    assert_same_size(expr_data, data_files, f'data files (in {data_dir})')
    n_results = len(expr_data)
    if not n_results:
        # Nothing to pair up (min()/max() below would fail on an empty list).
        print('No expressions found in', expr_filename, '- nothing to explain')
        return

    # File ID = the integer before the first '.' of the basename ("12.npz").
    file_ids = [int(os.path.basename(fn).split('.', 1)[0])
                for fn in data_files]
    assert len(file_ids) == len({*file_ids}), 'duplicate data file IDs!'
    assert min(file_ids) == 0, 'file ID index does not start at 0'
    assert max(file_ids) == (n_results - 1), (
        f'file ID index does not end with {n_results - 1} (number of results)')

    # The IDs are exactly 0..n-1, so sorting by ID lines the data files up
    # position-for-position with the loaded expression data.
    data_files = [fn for _, fn in sorted(zip(file_ids, data_files),
                                         key=lambda id_fn: id_fn[0])]

    # same as indexing as [start_at - 1::step_size]
    indices = slice(start_at - 1, None, step_size)
    file_ids = range(n_results)[indices]
    data_files = data_files[indices]
    expr_data = expr_data[indices]

    # The progress bar must count the jobs actually scheduled, not every
    # result in the pickle: slicing and debug mode both reduce that number.
    n_jobs_scheduled = len(file_ids)
    if debug:
        n_jobs_scheduled = min(n_jobs_scheduled, 1)

    def jobs():
        for i, data_file, expr_result in zip(file_ids, data_files, expr_data):
            out_filename = os.path.join(explainer_out_dir, str(i)) + '.npz'
            yield delayed(explain)(
                explainer_cls, out_filename, expr_result, data_file,
                max_explain, seed
            )
            if debug:  # one iteration
                break

    pbar = tqdm(desc=f'Evaluating {explainer}', total=n_jobs_scheduled)
    with tqdm_parallel(pbar):
        if n_jobs == 1 or debug:
            # joblib is bypassed here, so advance the bar ourselves.
            for func, args, kwargs in jobs():
                func(*args, **kwargs)
                pbar.update()
        else:
            Parallel(n_jobs=n_jobs)(jobs())
if __name__ == '__main__':
    import argparse

    def positive_int(x):
        """argparse ``type=`` converter that accepts only integers >= 1."""
        try:
            x = int(x)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f'{x} is not a valid integer.') from None
        if x < 1:
            raise argparse.ArgumentTypeError(f'{x} must be positive.')
        return x

    def main():
        """Parse the command line and dispatch to :func:`run`."""
        parser = argparse.ArgumentParser(  # noqa
            description='Generate explainer explanations of models and save '
                        'results to file',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter
        )
        parser.add_argument(
            'expr_filename', help='Filename of the expression pickle'
        )
        default_out_dir = os.path.join(
            os.path.dirname(__file__), 'experiment_data', 'explanations')
        parser.add_argument(
            '--out-dir', '-O', default=default_out_dir,
            help='Output directory to save explanations'
        )
        parser.add_argument(
            '--data-dir', '-D',
            help='Data directory where generated data for expr_filename exists'
        )
        parser.add_argument(
            '--explainer', '-X',
            choices=[*EXPLAINER_MAP.keys()],
            default='SHAP', help='The explainer to evaluate'
        )
        parser.add_argument(
            '--max-explain', type=positive_int,
            help='Maximum number of data points to explain per model'
        )
        parser.add_argument(
            '--n-jobs', '-j', default=-1, type=int,
            help='Number of jobs to use in generation'
        )
        parser.add_argument(
            '--seed', default=42, type=int,
            help='Seed for reproducibility'
        )
        parser.add_argument(
            '--start-at', default=1, type=positive_int,
            help='start index (1-indexed) for .pkl file'
        )
        parser.add_argument(
            '--step-size', default=1, type=positive_int,
            help='Size of increment for evaluation'
        )
        parser.add_argument(
            '--debug', action='store_true',
            help=argparse.SUPPRESS
        )
        args = parser.parse_args()

        data_dir = args.data_dir
        if data_dir is None:
            # Default layout: experiment_data/data/<expr basename w/o ext>
            # next to this script.
            data_dir = os.path.join(
                os.path.dirname(__file__), 'experiment_data', 'data',
                os.path.basename(args.expr_filename).rsplit('.', 1)[0]
            )
            if not os.path.isdir(data_dir):
                sys.exit(f'Could not infer --data-dir (guessed '
                         f'"{data_dir}"). Please supply this argument.')
        elif not os.path.isdir(data_dir):
            # Fail fast instead of after the (slow) expression pickle load.
            sys.exit(f'--data-dir "{data_dir}" is not a directory.')

        run(out_dir=args.out_dir,
            expr_filename=args.expr_filename,
            data_dir=data_dir,
            max_explain=args.max_explain,
            n_jobs=args.n_jobs,
            seed=args.seed,
            start_at=args.start_at,
            explainer=args.explainer,
            step_size=args.step_size,
            debug=args.debug)

    main()