-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_expressions.py
executable file
·451 lines (385 loc) · 14.9 KB
/
generate_expressions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
#!/usr/bin/env python
import argparse
import os
import pickle
import random
import re
import sys
import threading
import traceback
from datetime import datetime
from functools import wraps
from multiprocessing import TimeoutError

import numpy as np
import sympy as sp
from joblib import Parallel
from joblib import delayed
from tqdm.auto import tqdm

from posthoceval.models.synthetic import generate_additive_expression
from posthoceval.models.synthetic import valid_variable_domains
from posthoceval.models.synthetic import as_random_state
from posthoceval.utils import tqdm_parallel
from posthoceval.utils import dict_product
from posthoceval.utils import prod
# === DEBUG ===
from posthoceval.profile import profile
from posthoceval.profile import mem_profile
from posthoceval.profile import set_profile
# === DEBUG ===
from posthoceval.results import ExprResult
# Idents of helper threads currently running a wrapped call, mapped to the
# number of wrapped calls that have completed inside that thread.
_RUNNING_PERIODICITY_IDS = {}


def periodicity_wrapper(func):
    """A hacky fix for https://github.com/sympy/sympy/issues/20566

    Each top-level call of the returned wrapper runs ``func`` in a fresh
    helper thread. Calls made from inside that thread (nested calls of the
    wrapper, e.g. ``periodicity`` recursing into itself) run ``func``
    directly. When a nested call raises, or the per-thread count of completed
    calls exceeds ``sys.getrecursionlimit()``, the helper thread is unwound
    with ``SystemExit``. ``except Exception`` handlers inside ``func`` do not
    catch ``SystemExit``. The failure is then re-raised in the calling thread.

    :param func: callable to wrap (``sympy.periodicity`` in this script)
    :return: the wrapper (metadata copied from ``func`` via ``wraps``)
    :raises Exception: whatever ``func`` raised inside the helper thread
    :raises RecursionError: if the completed-call counter check trips
    """
    # State shared between a top-level call and its helper thread; lists so
    # that the nested ``wrapper`` can mutate them.
    ret_val = [None]
    raise_error = [False]
    other_exception = [None]

    @wraps(func)
    def wrapper(*args, _child_=None, **kwargs):
        ident = threading.get_ident()
        ident_present = ident in _RUNNING_PERIODICITY_IDS
        if _child_ is None:
            _child_ = ident_present

        if _child_ or ident_present:
            # Inside a helper thread: call through directly.
            if not ident_present:
                _RUNNING_PERIODICITY_IDS[ident] = 0
            if sys.getrecursionlimit() < _RUNNING_PERIODICITY_IDS[ident]:
                raise_error[0] = True
                # SystemExit unwinds every nested wrapper in this thread
                sys.exit()
            try:
                ret = func(*args, **kwargs)
            except Exception as e:
                raise_error[0] = True
                other_exception[0] = e
                sys.exit()
            _RUNNING_PERIODICITY_IDS[ident] += 1
            ret_val[0] = ret
            return ret

        # Top-level call: clear state left by any previous call. Without
        # this, one failure would make every later call re-raise it.
        ret_val[0] = None
        raise_error[0] = False
        other_exception[0] = None

        child_kwargs = dict(kwargs, _child_=True)

        def run_child():
            try:
                wrapper(*args, **child_kwargs)
            except SystemExit:
                # Expected: this is how the helper thread aborts on failure.
                # The outcome is reported through the shared state above.
                pass

        thread = threading.Thread(target=run_child)
        thread.start()
        thread_ident = thread.ident
        thread.join()
        _RUNNING_PERIODICITY_IDS.pop(thread_ident, None)

        failed = raise_error[0]
        exc = other_exception[0]
        result = ret_val[0]
        # Drop references so results/exceptions (and the frames held by their
        # tracebacks) are not kept alive until the next call.
        ret_val[0] = None
        raise_error[0] = False
        other_exception[0] = None

        if failed:
            if exc is not None:
                raise exc
            raise RecursionError(
                f'Maximum recursions ({sys.getrecursionlimit()}) in '
                f'{func} exceeded!'
            )
        return result

    return wrapper
# 🔧 MONKEY PATCH 🔧
# _...._
# .-. /
# /o.o\ ):.\
# \ / `- .`--._
# // / `-.
# '...\ . `.
# `--''. ' `.
# .' .' `-.
# .-' /`-.._ \
# .' _.' : .-'"'/
# | _,--` .' .' /
# \ \ / .' /
# \/// | ' | /
# \ ( `. ``-.
# \ \ `._ \
# _.-` ) .' )
# `.__.-' .-' _-.-'
# `.__,'
# ascii credit: https://www.asciiart.eu/animals/monkeys
#
# Build one wrapped `periodicity` and bind that same object to each public
# alias: top-level `sympy` first, then the defining module
# `sympy.calculus.util`, then the `sympy.calculus` package. Temporary names
# are deleted so they do not linger in this module's namespace.
_patched_periodicity = periodicity_wrapper(sp.periodicity)
for _namespace in (sp, sp.calculus.util, sp.calculus):
    _namespace.periodicity = _patched_periodicity
del _namespace, _patched_periodicity
def tqdm_write(*args, sep=' ', flush=False, **kwargs):
    """Print-like helper that writes through ``tqdm.write``.

    Writing via ``tqdm.write`` keeps output from clobbering active tqdm
    progress bars. ``run`` also installs this function as ``builtins.print``,
    so it accepts print's keyword arguments.

    :param args: objects to write; converted with ``str`` and joined by
        ``sep``
    :param sep: separator between ``args``; ``None`` means a single space,
        as with ``print``
    :param flush: ``tqdm.write`` has no ``flush`` parameter, so when true the
        target stream is flushed here after writing
    :param kwargs: forwarded to ``tqdm.write`` (e.g. ``file``, ``end``);
        ``end=None`` means a newline, as with ``print``
    """
    if sep is None:
        sep = ' '
    if 'end' in kwargs and kwargs['end'] is None:
        kwargs['end'] = '\n'
    tqdm.write(sep.join(map(str, args)), **kwargs)
    if flush:
        stream = kwargs.get('file')
        (sys.stdout if stream is None else stream).flush()
@profile
@mem_profile
def generate_expression(symbols, seed, verbose=0, timeout=None, **kwargs):
    """Generate one additive expression together with valid variable domains.

    Generation and domain search are retried until both succeed. Failures
    raising ``RuntimeError``, ``RecursionError`` or ``TimeoutError`` (the
    name imported from ``multiprocessing`` at the top of this file) are
    reported with their traceback and retried. Any other exception
    propagates.

    :param symbols: sympy symbols the expression is built from
    :param seed: seed for ``random``, numpy's global RNG, and the random
        state passed to `generate_additive_expression`
    :param verbose: forwarded to validation (``validate_kwargs``)
    :param timeout: forwarded to validation (``validate_kwargs``)
    :param kwargs: see `generate_additive_expression`
    :return: ``ExprResult`` holding the symbols, expression, domains, the
        random state after generation, and ``kwargs``
    """
    # sympy uses python random module in spots, set seed for reproducibility
    random.seed(seed, version=2)
    # I can't prove it but I think sympy also uses default numpy generator in
    # spots...
    np.random.seed(seed)
    # reproducibility, reseeded per job
    rs = as_random_state(seed)

    interval = sp.Interval(-1, +1)
    validate_kwargs = {'interval': interval, 'timeout': timeout,
                       'verbose': verbose}
    tries = 0
    while True:
        tries += 1
        expr = None
        tqdm_write('Generating expression...')
        try:
            expr = generate_additive_expression(
                symbols, seed=rs, validate=True,
                validate_kwargs=validate_kwargs, **kwargs
            )
            tqdm_write('Attempting to find valid domains...')
            # some or all terms redundant with validate in expr gen,
            # but results should be cached (with same args)
            domains = valid_variable_domains(expr, fail_action='error',
                                             **validate_kwargs)
        except (RuntimeError, RecursionError, TimeoutError):
            # `expr` is only assigned once generation succeeded, so it tells
            # which of the two steps failed.
            if expr is None:
                tqdm_write('Wow...failed to generate expression...')
            else:
                tqdm_write('Failed to find domains for:')
                tqdm_write(sp.pretty(expr))
            exc_lines = traceback.format_exception(
                *sys.exc_info(), limit=None, chain=True)
            for line in exc_lines:
                tqdm_write(line, file=sys.stderr, end='')
        else:
            break

    tqdm_write(f'Generated valid expression in {tries} tries.')
    tqdm_write(sp.pretty(expr))

    return ExprResult(
        symbols=symbols,
        expr=expr,
        domains=domains,
        state=rs.__getstate__(),  # TODO: get state before run.... -_-
        kwargs=kwargs
    )
def run(n_feats_range, n_runs, out_dir, seed, kwargs, n_jobs=-1, timeout=30):
    """Generate expressions for every configuration and pickle the results.

    For every ``n_feat`` in ``n_feats_range`` and every parameter combination
    yielded by ``dict_product(kwargs)`` (presumably the Cartesian product of
    the value sequences; the progress-bar total assumes so), ``n_runs``
    expressions are generated. Each job gets its own seed: ``seed``,
    ``seed + 1``, ... The list of results is written to
    ``<out_dir>/generated_expressions_<params>_<timestamp>.pkl``.

    :param n_feats_range: sequence of feature counts
    :param n_runs: expressions generated per (n_feat, parameter combination)
    :param out_dir: output directory, created if missing
    :param seed: starting seed
    :param kwargs: mapping of generation parameter name -> sequence of values;
        overrides the defaults below
    :param n_jobs: joblib ``n_jobs``; ``1`` runs sequentially in-process
    :param timeout: forwarded to `generate_expression`
    """
    import builtins

    os.makedirs(out_dir, exist_ok=True)

    # default kwargs
    default_kwargs = dict(
        n_main=None,
        n_uniq_main=None,
        n_interaction=0,
        n_uniq_interaction=None,
        interaction_ord=None,
        n_dummy=0,
        pct_nonlinear=None,  # default is 0.5
        nonlinear_multiplier=None,  # default depends on pct_nonlinear
        nonlinear_shift=0,
        nonlinear_skew=0,
        nonlinear_interaction_additivity=.5,
        nonlinear_single_multi_ratio='balanced',
        nonlinear_single_arg_ops=None,
        nonlinear_single_arg_ops_weights=None,
        nonlinear_multi_arg_ops=None,
        nonlinear_multi_arg_ops_weights=None,
        linear_multi_arg_ops=None,
        linear_multi_arg_ops_weights=None,
    )

    total_expressions = prod((
        n_runs, len(n_feats_range), *(len(v) for v in kwargs.values())
    ))

    def jobs():
        """Yield one delayed `generate_expression` call per run, advancing
        the seed after each so no two jobs share an RNG state."""
        nonlocal seed
        for n_feat in n_feats_range:
            for kw_val in dict_product(kwargs):
                job_kwargs = default_kwargs.copy()
                job_kwargs.update(kw_val)
                symbols = sp.symbols(f'x1:{n_feat + 1}', real=True)
                tqdm_write(f'{n_feat} features, generate expr with:')
                tqdm_write(job_kwargs)
                for _ in range(n_runs):
                    yield delayed(generate_expression)(
                        symbols, seed, timeout=timeout, **job_kwargs)
                    seed += 1

    progress = tqdm(desc='Expression Generation', total=total_expressions)
    with tqdm_parallel(progress):
        # Route print() through tqdm_write so stray output does not break the
        # progress bar. This patches the process-wide builtin, so restore it
        # once generation is done.
        # TODO: deal with this....
        original_print = builtins.print
        builtins.print = tqdm_write
        try:
            if n_jobs == 1:
                results = []
                for func, func_args, func_kwargs in jobs():
                    results.append(func(*func_args, **func_kwargs))
                    # The bar is not advanced automatically on this
                    # sequential path, so update it per finished job.
                    progress.update()
            else:
                results = Parallel(n_jobs=n_jobs)(jobs())
        finally:
            builtins.print = original_print

    param_str = '-vs-'.join(kwargs)
    now_str = datetime.now().isoformat(timespec='seconds').replace(':', '_')
    out_file = os.path.join(out_dir,
                            f'generated_expressions_{param_str}_{now_str}.pkl')
    print('Saving results to', out_file)
    with open(out_file, 'wb') as f:
        pickle.dump(
            results, f,
            protocol=4  # 4 is compatible with python 3.3+, 5 with 3.8+
        )
# Range argument syntax: "a[,b[,n][,log|linear]]". Numbers may be signed and
# may omit digits on either side of the decimal point ("-.5", "+2", "1.").
rx_range_t = re.compile(
    # a: int/float
    r'^\s*'
    r'([-+]?(?:\d+\.?\d*|\.\d+))'
    r'\s*'
    # b: int/float (optional, default: None, range comprises `a` only)
    r'(?:,\s*'
    r'([-+]?(?:\d+\.?\d*|\.\d+))'
    r'\s*'
    # n: int (optional, default: infer)
    r'(?:,\s*'
    r'(\d+)'
    r'\s*)?'
    # scale: str (optional, default: linear)
    r'(?:,\s*(log|linear))?'
    r'\s*)?'
    r'$'
)
range_pattern = ('"a[,b[,n][,log|linear]]" e.g. "1,10" or "-.5,.9,10" '
                 'or ".1,10,5,log" or "0.5" or "1,10,5"')


def range_type(value):
    """argparse ``type`` callable that parses a range specification.

    :param value: string matching ``rx_range_t``
    :return: ``(a, b, n, scale, dtype)``. ``a`` is a float; ``b`` is a float,
        or None for a single value; ``n`` is an int, or None to infer it;
        ``scale`` is ``'linear'`` or ``'log'``. ``dtype`` is ``float`` if
        ``a`` or ``b`` contains a decimal point, else ``int``.
    :raises argparse.ArgumentTypeError: on malformed input, ``n < 1`` or
        ``a >= b``
    """
    m = rx_range_t.match(value)
    if m is None:
        raise argparse.ArgumentTypeError(
            f'{value} does not match pattern {range_pattern}'
        )
    a, b, n, scale = m.groups()

    if '.' in a or (b and '.' in b):
        dtype = float
    else:
        dtype = int

    if n is not None:
        n = int(n)
        if n < 1:
            raise argparse.ArgumentTypeError(
                f'{value} contains an invalid range size ({n} cannot be '
                f'less than 1)'
            )
    if scale is None:
        scale = 'linear'

    a = float(a)
    if b:
        b = float(b)
        if a >= b:
            raise argparse.ArgumentTypeError(
                f'{value} is an invalid range ({a} is not less than {b})'
            )
    return a, b, n, scale, dtype


def arg_val_to_range(a, b, n, scale, inferred_dtype, dtype):
    """Expand a range parsed by `range_type` into the values to sweep over.

    :param a: start value (float)
    :param b: end value (float), or None for a single-value range
    :param n: number of samples, or None to infer (linear integer ranges
        only: ``b - a + 1``)
    :param scale: ``'linear'`` or ``'log'``
    :param inferred_dtype: dtype inferred by `range_type` (``int``/``float``)
    :param dtype: ``'int'``, ``'float'`` or ``'infer'``. With ``'infer'`` an
        int-looking range with explicit ``n`` stays integer only when
        ``b - a >= n / 2``.
    :return: ``[a]`` for a single value, otherwise a numpy array of samples
        (rounded to ``int`` when the range is integer)
    Exits the program via ``sys.exit`` on invalid combinations.
    """
    is_int = (dtype == 'int')
    inferred_int = (dtype == 'infer' and inferred_dtype is int)
    range_msg = f'[{a},{b}]'

    if is_int:
        int_msg = ('Range was explicitly specified as integer, however, '
                   'the {} value {} is not an integer in range ' +
                   range_msg)
        if not a.is_integer():
            sys.exit(int_msg.format('a', a))
        if not (b is None or b.is_integer()):
            sys.exit(int_msg.format('b', b))

    if b is None:
        if is_int or inferred_int:
            a = int(a)
        if n is not None and n != 1:
            sys.exit(f'Specified single value as the range ({a}) but the '
                     f'size is not 1 ({n}).')
        return [a]

    if inferred_int:
        is_int = ((n is None) or ((b - a) >= (n / 2)))
        print(f'Inferred range {range_msg} as '
              f'a{"n int" if is_int else " float"} interval.')

    if scale == 'linear':
        if n is None:
            if not is_int:
                sys.exit(f'Cannot infer the number of samples from a '
                         f'space with float dtype for range {range_msg}')
            # otherwise
            n = int(b - a + 1)
        space = np.linspace(a, b, n)
    elif scale == 'log':
        if n is None:
            sys.exit(f'Cannot use a log space without a defined number of '
                     f'samples for range {range_msg}')
        # np.geomspace raises on a zero endpoint and yields NaN when the
        # endpoints differ in sign; reject both up front.
        if a * b <= 0:
            sys.exit(f'A log space requires a and b to be nonzero and of '
                     f'the same sign for range {range_msg}')
        space = np.geomspace(a, b, n)
    else:
        raise ValueError(scale)

    if is_int:
        space = np.around(space).astype(int)
    return space


def main():
    """Parse command-line arguments and run expression generation."""
    parser = argparse.ArgumentParser(  # noqa
        description='Generate expressions and save to file',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '--n-runs', type=int, required=True,
        help='Number of runs (expressions generated) per data point'
    )
    parser.add_argument(
        '--n-feats-range', type=range_type, required=True,
        help=f'Range of number of features in expressions. '
             f'Expected format: {range_pattern}'
    )
    # TODO: interaction_ord, all ops args
    parser.add_argument(
        '--kwarg', required=True, action='append',
        choices=['n_main', 'n_uniq_main', 'n_interaction',
                 'n_uniq_interaction', 'interaction_ord', 'n_dummy',
                 'pct_nonlinear', 'nonlinear_multiplier', 'nonlinear_shift',
                 'nonlinear_skew', 'nonlinear_interaction_additivity',
                 'nonlinear_single_multi_ratio'],
        help='Name of the expression generation parameter that value range '
             'refers to'
    )
    parser.add_argument(
        '--kwarg-range', required=True, type=range_type, action='append',
        help=f'Range for kwarg. Expected format: {range_pattern}'
    )
    # default must be None: with action='append', argparse appends to the
    # default, and a str default has no .append (a list default would be
    # extended instead of replaced).
    parser.add_argument(
        '--kwarg-dtype', default=None, action='append',
        choices=('infer', 'int', 'float'),
        help='dtype for kwarg: one per --kwarg, or a single value applied to '
             'all. If omitted, the dtype of every kwarg is inferred'
    )
    default_out_dir = os.path.join(
        os.path.dirname(__file__), 'experiment_data', 'expr')
    parser.add_argument(
        '--out-dir', '-O', default=default_out_dir,
        help='Output directory to save generated expressions'
    )
    parser.add_argument(
        '--n-jobs', '-j', default=-1, type=int,
        help='Number of jobs (parallel)'
    )
    parser.add_argument(
        '--seed', default=42, type=int,
        help='Seed for reproducibility. Technically the starting seed '
             'from which each seed is derived per job'
    )
    parser.add_argument(
        '--profile', action='store_true',
        help='Profile this run'
    )

    args = parser.parse_args()

    if len(args.kwarg_range) != len(args.kwarg):
        sys.exit('The arguments --kwarg and --kwarg-range '
                 'must all have the same number of arguments. Received: '
                 f'{len(args.kwarg)}, {len(args.kwarg_range)}')
    duplicates = sorted({k for k in args.kwarg if args.kwarg.count(k) > 1})
    if duplicates:
        # later ranges would silently overwrite earlier ones in `kwargs`
        sys.exit('Each --kwarg may only be given once. Duplicated: '
                 f'{", ".join(duplicates)}')

    set_profile(args.profile)

    if args.kwarg_dtype is None:
        kwarg_dtype = ['infer'] * len(args.kwarg)
    elif len(args.kwarg_dtype) == len(args.kwarg):
        kwarg_dtype = args.kwarg_dtype
    elif len(args.kwarg_dtype) == 1:
        print('Using provided --kwarg-dtype for all kwarg arguments')
        kwarg_dtype = args.kwarg_dtype * len(args.kwarg)
    else:
        sys.exit('Provided --kwarg-dtype must be the same size as --kwarg '
                 'arguments, or a single value to be applied to all '
                 '--kwarg arguments.')

    n_feats_range = arg_val_to_range(*args.n_feats_range, dtype='int')
    if min(n_feats_range) < 1:
        sys.exit('--n-feats-range must only contain positive numbers of '
                 f'features, got {list(n_feats_range)}')

    kwargs = {}
    for kwarg, range_t, dtype in zip(
            args.kwarg, args.kwarg_range, kwarg_dtype):
        kwargs[kwarg] = arg_val_to_range(*range_t, dtype=dtype)

    run(
        n_feats_range=n_feats_range,
        n_runs=args.n_runs,
        out_dir=args.out_dir,
        n_jobs=args.n_jobs,
        seed=args.seed,
        kwargs=kwargs
    )


if __name__ == '__main__':
    main()