-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathpdf.py
725 lines (568 loc) · 24.5 KB
/
pdf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
# -*- coding: utf-8 -*-
""" PDF base class

The pdf base class is the main interface to ROOT. This class basically serves as a wrapper around
RooAbsPdf of the RooFit package and takes, monitors and initialises observables and parameters.

Examples:
    How to add a new wrapper to a ROOT.RooAbsPdf:

    .. code-block:: python

        class MyPDF(PDF):
            def __init__(self, x, param, name='MyPDF', **kwds):
                super(MyPDF, self).__init__(name=name, **kwds)
                x = self.add_observable(x)
                param1 = self.add_parameter(param, 'param')
                self.roo_pdf = ROOT.MyPDF(self.name, self.title, x, param1)
"""
from __future__ import print_function
import ROOT
from .utilities import ClassLoggingMixin, AttrDict
from .data import df2roo
from .plotting import fast_plot
from .observables import create_roo_variable
class PDF(ClassLoggingMixin, object):
    """ Base class for the ROOT.RooFit wrapper

    Attributes:
        name (str): Name of the pdf
        title (str): Title of the pdf, displayed automatically in the legend
        observables (dict): Dictionary of the fit observables, keyed by their ROOT name
        parameters (dict): Dictionary of the fit parameters, keyed by their name within the PDF
        parameter_names (dict): Maps the PDF-internal parameter name to its ROOT name
        roo_pdf (:obj:`ROOT.RooAbsPdf`): Base ROOT.RooFit object of the wrapper
        last_data (:obj:`ROOT.RooAbsData`): Reference to the last fit data
        last_fit: Reference to the last fit result

    Todo:
        * Maybe remove observables key
        * Add convolution to @ overwrite?
    """

    #: Storage of the roo fit data and last fit
    last_data = None
    last_fit = None

    #: bool: Fit options for minuit
    use_minos = False
    use_hesse = True
    use_extended = False
    use_sumw2error = True

    def __init__(self, name, observables=None, title=None, **kwds):
        """ Init of the PDF class

        Args:
            name (:obj:`str`): Name of the model
            observables (:obj:`list` of :obj:`ROOT.RooRealVar`, optional): Deprecated
            title (:obj:`str`, optional): Title of the model, defaults to ``name``
            **kwds: Forwarded to the base class initialiser (may be removed)
        """
        super(PDF, self).__init__(**kwds)

        #: Unique identifier of the PDF
        self.name = name
        #: Title of the PDF
        self.title = title
        if self.title is None:
            self.title = self.name

        #: dict(str->ROOT.RooRealVar) - Input variables from the data frame
        self.observables = AttrDict()
        if observables:
            for observable in observables:
                self.add_observable(observable)

        #: dict(str->ROOT.RooRealVar) - Fitted parameters from the fit procedure
        self.parameters = AttrDict()
        #: dict(str->str) - PDF-internal parameter name -> ROOT name
        self.parameter_names = AttrDict()

        #: RooAbsPDF
        self.roo_pdf = None
        self.init_pdf()

        #: int: Flag for the ROOT output
        self.print_level = -1

    def __call__(self):
        """ Call overwrite

        Returns:
            ROOT.RooAbsPdf base model
        """
        return self.roo_pdf

    def __add__(self, other):
        """ Add operator overwrite

        Args:
            other (:obj:`PDF`): Pdf to be added

        Returns:
            AddPdf of the two PDF objects
        """
        from .composites import AddPdf
        return AddPdf([self, other])

    def __mul__(self, other):
        """ Mul operator overwrite

        Args:
            other (:obj:`PDF`): Pdf to be multiplied

        Returns:
            ProdPdf of the two PDF objects
        """
        from .composites import ProdPdf
        return ProdPdf([self, other])

    def init_pdf(self):
        """ Expose every registered parameter as an instance attribute of the same name
        """
        for p in self.parameters:
            setattr(self, p, self.parameters[p])

    def add_parameter(self, param_var, param_name=None, final_name=None, **kwds):
        """ Add fit parameter

        Args:
            param_var (list or ROOT.RooRealVar): Initialisation of the parameter as list or ROOT object
            param_name (:obj:`str`): Name of the parameter within the object (Not within ROOT namespace!).
                Required unless ``final_name`` is given, in which case it defaults to ``final_name``.
            final_name (:obj:`str`, optional): Name of the parameter within PDF and ROOT namespace
            **kwds: create_roo_variable keywords

        Returns:
            ROOT.RooRealVar reference to fit parameter
        """
        if final_name is None:
            assert param_name is not None, "Please specify a parameter name"
            name = self.name + '_' + param_name
        else:
            name = final_name
            if param_name is None:
                # Without an internal name the parameter used to be stored under the key None and
                # setattr(self, None, ...) raised; reuse the ROOT name instead.
                param_name = final_name
        roo_param = create_roo_variable(param_var, name=name, **kwds)
        self.parameters[param_name] = roo_param
        self.parameter_names[param_name] = name
        setattr(self, param_name, roo_param)
        return self.parameters[param_name]

    def add_observable(self, observable_var, **kwds):
        """ Adding an observable to the PDF

        Observables are used in the PDF class to convert relevant columns in pandas.DataFrames.
        The observable is stored in ``self.observables`` under its ROOT name.

        Args:
            observable_var (list or ROOT.RooRealVar): Initialisation of the observable as list or ROOT object
            **kwds: create_roo_variable keywords

        Returns:
            ROOT.RooRealVar reference to fit observable
        """
        if isinstance(observable_var, (list, tuple)):
            # A list/tuple whose first entry is not a string carries no name
            if not isinstance(observable_var[0], str):
                self.warn("WARNING : choosing automatic variable name 'x'")
        roo_observable = create_roo_variable(observable_var, **kwds)
        name = roo_observable.GetName()
        self.observables[name] = roo_observable
        return self.observables[name]

    def get_fit_data(self, df, weights=None, observables=None, nbins=None, *args, **kwargs):
        """ Convert pandas.DataFrame to ROOT.RooAbsData containing only relevant columns

        Args:
            df (:obj:`DataFrame` or :obj:`array`): Fit data
            weights (:obj:`str` or :obj:`array`, optional): Column name of weights or weights data
            observables (dict, optional): Dictionary of the observables to be converted,
                defaults to ``self.observables``
            nbins (int, optional): Number of bins; if given, a ROOT.RooDataHist is created instead
            *args, **kwargs: Forwarded to df2roo

        Returns:
            ROOT.RooDataSet or ROOT.RooDataHist of relevant columns and rows of the input data
        """
        if observables is None:
            observables = self.observables
        roo_data = df2roo(df, observables=observables, weights=weights, bins=nbins, *args, **kwargs)
        return roo_data

    def fit(self, df, weights=None, nbins=None, *args, **kwargs):
        """ Fit a pandas or numpy data to the PDF

        The converted data is stored in ``self.last_data`` and the result in ``self.last_fit``.

        Args:
            df (:obj:`DataFrame` or :obj:`array`): Fit data
            weights (:obj:`str` or :obj:`array`, optional): Column name of weights or weights data
            nbins (int, optional): Number of bins; if given, a ROOT.RooDataHist is fitted instead
            *args, **kwargs: Additional options forwarded to ``RooAbsPdf.fitTo``

        Returns:
            ROOT.RooFitResult of the fit
        """
        self.logger.debug("Fitting")
        self.last_data = self.get_fit_data(df, weights=weights, nbins=nbins)
        self._fit(self.last_data, *args, **kwargs)
        return self.last_fit

    def _before_fit(self, *args, **kwargs):
        """ Template function before fit

        This function is called before the fit and can be overwritten for certain use cases.
        """
        pass

    def _fit(self, data_roo, *args, **kwargs):
        """ Internal fit function

        Runs ``_before_fit`` and then ``roo_pdf.fitTo`` with the fit options taken from the class
        flags (``use_extended``, ``use_sumw2error``, ``use_minos``, ``use_hesse``) and
        ``print_level``. The result is stored in ``self.last_fit``.

        Args:
            data_roo (ROOT.RooDataSet): Dataset to fit on the internal RooAbsPdf
            *args, **kwargs: Additional options forwarded to ``RooAbsPdf.fitTo``
        """
        self._before_fit()
        self.logger.info("Performing fit")
        self.last_fit = self.roo_pdf.fitTo(data_roo,
                                           ROOT.RooFit.Save(True),
                                           ROOT.RooFit.PrintLevel(self.print_level),
                                           ROOT.RooFit.PrintEvalErrors(-1),
                                           ROOT.RooFit.Extended(self.use_extended),
                                           ROOT.RooFit.SumW2Error(self.use_sumw2error),
                                           ROOT.RooFit.Minos(self.use_minos),
                                           ROOT.RooFit.Hesse(self.use_hesse), *args, **kwargs)

    def plot(self, filename, data=None, observable=None, *args, **kwargs):
        """ Default plotting function

        With several observables and no ``observable`` given, one file per observable is written,
        named ``<stem>_<observable>.<suffix>``.

        Args:
            filename (str): Name of the output file. Suffix determines file type.
            data (DataFrame or ROOT.RooDataSet, optional): Data to be plotted in the fit,
                defaults to the last fit data
            observable (:obj:`ROOT.RooAbsReal` or str, optional): In case of multiple dimensions draw
                projection to specified observable.
            *args: Arguments for fast_plot
            **kwargs: Keyword arguments for fast_plot
        """
        if self.last_data is None and data is None:
            self.logger.error("There is no fit data")
            return

        if data is not None:
            import pandas as pd
            if isinstance(data, pd.DataFrame):
                data = self.get_fit_data(data)

        # Split the suffix off at the LAST dot. Splitting at the first occurrence of
        # '.<suffix>' truncated names such as 'fit.pdf_v2/plot.pdf' to 'fit'.
        if '.' in filename:
            filename, suffix = filename.rsplit('.', 1)
        else:
            # NOTE(review): kept from the original behaviour - without a dot the whole name is
            # also used as the suffix; confirm what fast_plot expects here.
            suffix = filename

        # Find the observable in case there are more observables
        if observable is None:
            if len(self.observables) == 1:
                observable = self.get_observable()
            else:
                for o in self.observables:
                    try:
                        self.logger.info('Plotting ' + o)
                        self._plot(filename + '_' + o + '.' + suffix, self.observables[o], data, *args, **kwargs)
                    except AttributeError:
                        self.logger.error("There was a plotting error")
                return
        elif isinstance(observable, str):
            # Resolve a ROOT name to the registered observable; unknown names are passed through
            for o in self.observables:
                if observable == self.observables[o].GetName():
                    observable = self.observables[o]
                    break

        self._plot(filename + '.' + suffix, observable, data, *args, **kwargs)

    def _plot(self, filename, observable, data=None, *args, **kwargs):
        """ plot function to be overwritten

        Args:
            filename (str): Name of the output file. Suffix determines file type.
            observable (:obj:`ROOT.RooAbsReal` or str): Observable to project on.
            data (DataFrame or ROOT.RooDataSet, optional): Data to be plotted in the fit,
                defaults to the last fit data
            *args: Arguments for fast_plot
            **kwargs: Keyword arguments for fast_plot
        """
        if data is None:
            data = self.last_data
        fast_plot(self.roo_pdf, data, observable, filename, *args, **kwargs)

    def _get_var(self, v, as_ufloat=False):
        """ Internal getter for parameter values

        Variables without ``getError`` (e.g. RooFormulaVar) use the error propagated from the last
        fit result; if that is not possible the error is 0.

        Args:
            v: Parameter name
            as_ufloat: Return ufloat object (falls back to a tuple if `uncertainties` is missing)

        Returns:
            :obj:`tuple` or :obj:`ufloat` mean and error of parameter
        """
        mes = self.parameters[v]
        val = mes.getVal()
        # Now catch RooFormulaVar
        try:
            err = mes.getError()
        except AttributeError:
            try:
                err = mes.getPropagatedError(self.last_fit)
            except TypeError:
                err = 0
        if not as_ufloat:
            return val, err
        try:
            from uncertainties import ufloat
            return ufloat(val, err)
        except ImportError:
            return val, err

    def get(self, parameter=None, as_ufloat=False):
        """ Get one of the fitted parameters or print all if None is set

        Args:
            parameter (str, optional): name of the parameter
            as_ufloat (bool, optional): If true return ufloat object, else tuple

        Returns:
            :obj:`tuple` or :obj:`ufloat` mean and error of parameter; None when all are printed
        """
        if parameter is None:
            for m in self.parameters:
                print('{0:18} ==> {1}'.format(m, self._get_var(m, True)))
        else:
            return self._get_var(parameter, as_ufloat)

    def get_observable(self):
        """ Get the (first) observable from self.observables

        Returns:
            ROOT.RooRealVar of the first registered observable
        """
        assert len(self.observables) > 0, "There is not obsrvable"
        if len(self.observables) > 1:
            # Logger.warn is a deprecated alias of Logger.warning
            self.logger.warning('There are more than one observables, returning first')
        return next(iter(self.observables.values()))

    def fix(self, constant=True):
        """ Fix all parameters of the PDF

        Args:
            constant (bool, default=True): Set all parameters constant (False releases them)
        """
        for m in self.parameters:
            self.logger.debug("Setting %s constant" % m)
            self.parameters[m].setConstant(constant)

    def constrain(self, sigma, param=None):
        """ Constrain parameters within given significance

        Sets the limits of each parameter to ``value +- sigma * error``.
        Use only with an existing fit result.

        Args:
            sigma (int or float): Interval to constrain the parameter in units of its error.
            param (str, optional): Name of the only parameter to constrain
        """
        for m in self.parameters:
            # Names must be compared by value; the former identity check ('is not') only
            # matched when both strings happened to be the same object.
            if param is not None and m != param:
                continue
            par = self.parameters[m]
            cent = par.getVal()
            err = par.getError()
            par.setMin(cent - sigma * err)
            par.setMax(cent + sigma * err)

    def narrow(self, sigma=1):
        """ Narrow all parameters to ``value +- sigma * error``, never widening the original limits

        Args:
            sigma (int or float): Interval in units of the parameter error.
        """
        for m in self.parameters:
            par = self.parameters[m]
            low = par.getMin()
            high = par.getMax()
            val = par.getVal()
            err = par.getError()
            par.setMax(min(val + sigma * err, high))
            par.setMin(max(val - sigma * err, low))

    def randomize_pdf(self, frac=1/6., exceptions=None, only=None):
        """ Randomize parameters of a pdf

        Each parameter is shifted by a Gaussian of width ``frac * (max - min)``; if the shifted
        value leaves the range, a uniform value within the range is drawn instead.

        Args:
            frac: width of the Gauss added to each member, relative to the parameter range
            exceptions (list): List of excluded parameters
            only (list): List of parameters to include
        """
        import random
        params = self.parameters if only is None else only
        for m in params:
            if exceptions is not None and m in exceptions:
                continue
            try:
                max_ = self.parameters[m].getMax()
                min_ = self.parameters[m].getMin()
                dist = abs(max_ - min_)
                to_add = random.normalvariate(0, dist * frac)
                val = self.parameters[m].getVal()
                if min_ < (val + to_add) < max_:
                    # only a small gaussian blur
                    self.parameters[m].setVal(val + to_add)
                else:
                    self.parameters[m].setVal(random.uniform(min_, max_))
            except AttributeError:
                self.logger.error("Unable to randomize parameter " + m)

    def refit(self, randomize=False, exceptions=None, only=None):
        """ Fit the last fit data again, optionally after randomizing the start parameters

        Args:
            randomize (bool): Call ``randomize_pdf`` before fitting
            exceptions (list, optional): Parameters excluded from randomization
            only (list, optional): Only randomize these parameters
        """
        if randomize:
            self.randomize_pdf(exceptions=exceptions, only=only)
        self._fit(self.last_data)

    def _first_unconverged(self, params, err_lim_low, err_lim_high, exceptions, assym, ignore_n):
        """ Return (name, value, error) of the first parameter whose error is outside the limits, else None
        """
        for p in params:
            if exceptions is not None and p in exceptions:
                continue
            if ignore_n and 'n_' in p:
                continue
            par = self.parameters[p]
            p_err = par.getError()
            if assym:
                p_err = min(abs(par.getErrorHi()), abs(par.getErrorLo()))
            if not err_lim_low < p_err < err_lim_high:
                return p, par.getVal(), p_err
        return None

    def check_convergence(self, err_lim_low, err_lim_high, n_refit=20,
                          only=None, exceptions=None, assym=False, ignore_n=True):
        """ Check convergence of PDF and refit

        A fit counts as converged when every checked parameter error lies strictly between
        ``err_lim_low`` and ``err_lim_high``. Otherwise the pdf is randomized and refitted,
        at most ``n_refit`` times.

        Args:
            err_lim_low (float): Exclusive lower bound for an acceptable parameter error
            err_lim_high (float): Exclusive upper bound for an acceptable parameter error
            n_refit (int): Maximum number of randomized refits
            only (list, optional): Only check (and randomize) these parameters
            exceptions (list, optional): Parameters to skip
            assym (bool): Use min(|errorHi|, |errorLo|) instead of the symmetric error
            ignore_n (bool): Skip parameters whose name contains 'n_'
                (presumably yields - confirm)

        Returns:
            bool: True if converged, False if still not converged after all refits
        """
        params = [p for p in self.parameters] if only is None else only
        while True:
            suspect = self._first_unconverged(params, err_lim_low, err_lim_high,
                                              exceptions, assym, ignore_n)
            if suspect is None:
                return True
            name, value, error = suspect
            self.warn("Fit not converged due to %s (%.4f +-%.4f),"
                      " try %d refitting " % (name, value, error, n_refit))
            # '<= 0' also terminates for negative inputs, which used to recurse without end
            if n_refit <= 0:
                return False
            self.refit(randomize=True, exceptions=exceptions, only=only)
            n_refit -= 1

    def get_parameters(self, par_name=None):
        """ Get values for parameters

        Args:
            par_name (str, optional): Only include parameters whose name contains this string

        Returns:
            dict mapping parameter name to [value, error, min, max]
        """
        ret = {}
        for p in self.parameters:
            if par_name is not None and par_name not in p:
                continue
            par = self.parameters[p]
            ret[p] = [par.getVal(), par.getError(), par.getMin(), par.getMax()]
        return ret

    def set_parameter(self, p, params):
        """ Set parameter to value, error and limits, Experimental

        Args:
            p (str): Name of the parameter
            params (scalar, list or tuple): value, [value, error] or [value, error, min, max]
        """
        assert p in self.parameters.keys(), 'Parameter not found'
        par = self.parameters[p]
        if isinstance(params, tuple):
            params = list(params)
        elif not isinstance(params, list):
            params = [params]
        if len(params) == 1:
            par.setVal(params[0])
        elif len(params) == 2:
            par.setVal(params[0])
            par.setError(params[1])
        elif len(params) == 4:
            # Set the limits first: RooRealVar.setVal clips into the current range, so setting
            # the value before widening the range would lose values outside the old limits.
            par.setRange(params[2], params[3])
            par.setVal(params[0])
            par.setError(params[1])
        else:
            self.error("Could not set parameter")

    def set_parameters(self, pars):
        """ Set parameters in the pdf, Experimental

        Args:
            pars (dict or list): dict of parameter name -> values accepted by ``set_parameter``
        """
        if isinstance(pars, list):
            # NOTE(review): this branch indexes the list with a string (pars[p]) and therefore
            # raises TypeError whenever a match is found; intended input format unclear - confirm.
            for p in pars:
                ps1 = p.split('_')
                if not len(ps1) >= 2:
                    self.warn("Parameter %s can not be set" % p)
                    continue
                name_remote = ps1[-1]
                for sp in self.parameters:
                    ps2 = sp.split('_')
                    if not len(ps2) >= 2:
                        self.warn("Parameter %s can not be set" % sp)
                        continue
                    name_self = ps2[-1]
                    if name_remote in name_self:
                        self.debug("Setting parameter %s in %s" % (p, sp))
                        self.set_parameter(sp, pars[p])
        elif isinstance(pars, dict):
            for p in pars:
                self.debug("Setting parameter %s in %s" % (p, pars[p]))
                self.set_parameter(p, pars[p])
        else:
            self.error("Please provide list or dict")

    def get_curve(self, observable=None, norm=1, npoints=1000):
        """ Get projection of the pdf curve

        The y values are rescaled so that they sum to ``npoints`` and then multiplied by ``norm``.

        Args:
            observable (str, optional): Name of the observable, defaults to the first observable
            norm (float, optional): normalisation for the pdf, default=1
            npoints (int): number of points, default=1000

        Returns:
            hx, hy : numpy arrays of x (bin centres) and y points of the pdf projection

        Examples:
            >>> import matplotlib.pyplot as plt
            >>> plt.plot(*pdf.get_curve())
        """
        import root_numpy
        import numpy as np
        if observable is not None:
            assert isinstance(observable, str), "please specify the name of the observable"
            assert observable in self.observables, "observable not found"
        else:
            observable = self.get_observable().GetName()
        h = self.roo_pdf.createHistogram(observable, npoints)
        hy, hx = root_numpy.hist2array(h, False, False, True)
        # Normalise y
        hy = npoints * hy / np.sum(hy)
        # center x
        hx = (hx[0][:-1] + hx[0][1:]) / 2.
        return hx, np.multiply(hy, norm)

    def get_components_curve(self, norm=1, npoints=1000):
        """ Get individual (normed) components of a composite pdf

        Only meaningful when ``self`` is an AddPdf; relies on its ``norms`` and ``pdfs`` mappings.

        Args:
            norm (float, optional): normalisation for the pdf, default=1
            npoints (int): number of points, default=1000

        Returns:
            curves (dict/False): component name -> (hx, hy) curve; False if the pdf is not composite
        """
        from .composites import AddPdf
        if not isinstance(self, AddPdf):
            return False
        total_norm = sum(self.norms[n].getVal() for n in self.norms)
        curves = {}
        for p in self.pdfs:
            curves[p] = self.pdfs[p].get_curve(norm=norm * self.norms[p].getVal() / total_norm,
                                               npoints=npoints)
        return curves

    def get_fwhm(self, observable=None, npoints=1000):
        """ Calculate Full width at half maximum - EXPERIMENTAL

        Args:
            observable (str, optional): Name of the observable, defaults to the first observable
            npoints (int): number of histogram bins used for the estimate

        Returns:
            FWHM as distance between the first and last bin centre above half maximum
        """
        if observable is not None:
            assert isinstance(observable, str), "please specify the name of the observable"
            assert observable in self.observables, "observable not found"
        else:
            observable = self.get_observable().GetName()
        h = self.roo_pdf.createHistogram(observable, npoints)
        half_max = h.GetMaximum() / 2.
        bin1 = h.FindFirstBinAbove(half_max)
        bin2 = h.FindLastBinAbove(half_max)
        return h.GetBinCenter(bin2) - h.GetBinCenter(bin1)

    def get_pull(self, observable, nbins):
        """ Get the values of the pull distribution for a fitted pdf.

        RooFit only calculates pulls on plottable objects via a frame, so the last fit data
        ('Data') and the model curve ('Model') are plotted on a frame first.
        TODO: use the objects observable(s)?

        Args:
            observable: RooRealVar used for the fit
            nbins: number of bins for the pull

        Returns:
            python list containing the pull values
        """
        frame = observable.frame(ROOT.RooFit.Title("Fit Result"), ROOT.RooFit.Bins(nbins))
        self.last_data.plotOn(frame, ROOT.RooFit.Name("Data"), ROOT.RooFit.DataError(ROOT.RooAbsData.SumW2))
        self.roo_pdf.plotOn(frame, ROOT.RooFit.Name("Model"), ROOT.RooFit.LineColor(1), ROOT.RooFit.Range("Full"))
        pulls = frame.pullHist("Data", "Model", True)
        pull_buffer = pulls.GetY()
        return [pull_buffer[i] for i in range(pulls.GetN())]

    def get_sampled_goodness_of_fit(self, observable, nbins=100):
        """ Sampled chi2 between the fitted pdf and the data

        Since most of the fits are unbinned, this chi2 is not extracted from the fit: it is
        rather a distance between fit function and data, computed from the pulls in ``nbins``
        bins. For bins -> infinity this should converge to the 'goodness of fit'.
        Two assumptions are made:
            - the data is not binned more finely than ``nbins`` bins in the given interval
            - there is enough data to get a decent population in every bin

        Args:
            observable: RooRealVar used for the fit
            nbins: number of bins for the pull

        Returns:
            tuple (chisquared, chisquared/ndf, pvalue, ndf), with ndf = number of pulls minus
            number of parameters

        Raises:
            ValueError: if ndf is not positive
        """
        pull_values = self.get_pull(observable=observable, nbins=nbins)
        n_params = len(self.parameters)
        ndf = len(pull_values) - n_params
        if ndf <= 0:
            raise ValueError("Number of degrees of freedom must be positive, got %d "
                             "(%d pulls, %d parameters); increase nbins"
                             % (ndf, len(pull_values), n_params))
        chi2 = sum(x * x for x in pull_values)
        reduced_chi2 = chi2 / ndf
        pvalue = ROOT.TMath.Prob(chi2, ndf)
        return chi2, reduced_chi2, pvalue, ndf