Skip to content

Commit

Permalink
cast all masks first
Browse files Browse the repository at this point in the history
  • Loading branch information
kratsg committed Dec 21, 2024
1 parent c5cbdba commit 6cd7044
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 28 deletions.
13 changes: 9 additions & 4 deletions src/pyhf/modifiers/histosys.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ class histosys_combined:
def __init__(
self, modifiers, pdfconfig, builder_data, interpcode='code0', batch_size=None
):
default_backend = pyhf.default_backend

self.batch_size = batch_size
self.interpcode = interpcode
assert self.interpcode in ['code0', 'code2', 'code4p']
Expand Down Expand Up @@ -128,10 +130,13 @@ def __init__(
]
for m in keys
]
self._histosys_mask = [
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
]
self._histosys_mask = default_backend.astensor(
[
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
],
dtype='bool',
)

if histosys_mods:
self.interpolator = getattr(interpolators, self.interpcode)(
Expand Down
14 changes: 10 additions & 4 deletions src/pyhf/modifiers/lumi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

import pyhf
from pyhf import get_backend, events
from pyhf.parameters import ParamViewer

Expand Down Expand Up @@ -60,6 +61,8 @@ class lumi_combined:
op_code = 'multiplication'

def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None):
default_backend = pyhf.default_backend

self.batch_size = batch_size

keys = [f'{mtype}/{m}' for m, mtype in modifiers]
Expand All @@ -72,10 +75,13 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None):
)
self.param_viewer = ParamViewer(parfield_shape, pdfconfig.par_map, lumi_mods)

self._lumi_mask = [
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
]
self._lumi_mask = default_backend.astensor(
[
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
],
dtype='bool',
)
self._precompute()
events.subscribe('tensorlib_changed')(self._precompute)

Expand Down
14 changes: 10 additions & 4 deletions src/pyhf/modifiers/normfactor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

import pyhf
from pyhf import get_backend, events
from pyhf.parameters import ParamViewer

Expand Down Expand Up @@ -58,6 +59,8 @@ class normfactor_combined:
op_code = 'multiplication'

def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None):
default_backend = pyhf.default_backend

self.batch_size = batch_size

keys = [f'{mtype}/{m}' for m, mtype in modifiers]
Expand All @@ -72,10 +75,13 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None):
parfield_shape, pdfconfig.par_map, normfactor_mods
)

self._normfactor_mask = [
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
]
self._normfactor_mask = default_backend.astensor(
[
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
],
dtype='bool',
)
self._precompute()
events.subscribe('tensorlib_changed')(self._precompute)

Expand Down
14 changes: 10 additions & 4 deletions src/pyhf/modifiers/normsys.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

import pyhf
from pyhf import get_backend, events
from pyhf import interpolators
from pyhf.parameters import ParamViewer
Expand Down Expand Up @@ -72,6 +73,8 @@ class normsys_combined:
def __init__(
self, modifiers, pdfconfig, builder_data, interpcode='code1', batch_size=None
):
default_backend = pyhf.default_backend

self.interpcode = interpcode
assert self.interpcode in ['code1', 'code4']

Expand All @@ -97,10 +100,13 @@ def __init__(
]
for m in keys
]
self._normsys_mask = [
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
]
self._normsys_mask = default_backend.astensor(
[
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
],
dtype='bool',
)

if normsys_mods:
self.interpolator = getattr(interpolators, self.interpcode)(
Expand Down
11 changes: 7 additions & 4 deletions src/pyhf/modifiers/shapefactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,13 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None):
parfield_shape, pdfconfig.par_map, shapefactor_mods
)

self._shapefactor_mask = [
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
]
self._shapefactor_mask = default_backend.astensor(
[
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
],
dtype='bool',
)

global_concatenated_bin_indices = [
[[j for c in pdfconfig.channels for j in range(pdfconfig.channel_nbins[c])]]
Expand Down
11 changes: 7 additions & 4 deletions src/pyhf/modifiers/shapesys.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,13 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None):
parfield_shape, pdfconfig.par_map, self._shapesys_mods
)

self._shapesys_mask = [
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
]
self._shapesys_mask = default_backend.astensor(
[
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
],
dtype='bool',
)
self.__shapesys_info = default_backend.astensor(
[
[
Expand Down
11 changes: 7 additions & 4 deletions src/pyhf/modifiers/staterror.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,13 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None):
parfield_shape, pdfconfig.par_map, self._staterr_mods
)

self._staterror_mask = [
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
]
self._staterror_mask = default_backend.astensor(
[
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
],
dtype='bool',
)
global_concatenated_bin_indices = [
[[j for c in pdfconfig.channels for j in range(pdfconfig.channel_nbins[c])]]
]
Expand Down

0 comments on commit 6cd7044

Please sign in to comment.