# default_exp models
This module implements a very lightweight version of pyhf-like model building. For now, some numbers (parameter bounds and initial values) are hard-coded to suit the three-Gaussian-blobs demonstration.
#export
import jax
import jax.numpy as jnp
import pyhf
pyhf.set_backend("jax")
# class-based
class _Config(object):
    """Hard-coded stand-in for a model config with two parameters.

    The parameter of interest sits at index 0; both parameters start at 1.0
    and are bounded to the interval [0.0, 10.0].
    """

    def __init__(self):
        # Two parameters in total, with the parameter of interest first.
        self.poi_index, self.npars = 0, 2

    def suggested_init(self):
        """Return the initial value of each parameter as a length-2 array."""
        initial_values = [1.0, 1.0]
        return jnp.asarray(initial_values)

    def suggested_bounds(self):
        """Return the (lower, upper) bounds of each parameter as a 2x2 array."""
        per_parameter_bounds = [[0.0, 10.0], [0.0, 10.0]]
        return jnp.asarray(per_parameter_bounds)
class Model(object):
    """Dummy class to mimic the functionality of `pyhf.Model`.

    The model has two parameters, ``(mu, gamma)``: ``mu`` scales the signal
    yields and ``gamma`` scales the nominal background yields. ``gamma`` is
    constrained by a Poisson term built from the background uncertainty.
    """

    def __init__(self, spec):
        """Build the model from ``spec = (signal, nominal_bkg, bkg_uncert)``.

        Each entry is converted with ``jnp.asarray`` so that plain Python
        sequences work in the element-wise arithmetic below.
        """
        signal, nominal, uncert = spec
        self.sig = jnp.asarray(signal)
        self.nominal = jnp.asarray(nominal)
        self.uncert = jnp.asarray(uncert)
        # Squared ratio of nominal yield to its uncertainty; this is the rate
        # scale of the Poisson constraint term used in `logpdf`.
        self.factor = (self.nominal / self.uncert) ** 2
        # Auxiliary data is the constraint rate evaluated at gamma = 1.
        self.aux = 1.0 * self.factor
        self.config = _Config()

    def expected_data(self, pars, include_auxdata=True):
        """Return the expected data for parameters ``pars = (mu, gamma)``.

        With ``include_auxdata=True`` (the default) the result stacks two
        rows: the expected main yields ``gamma * nominal + mu * sig`` and the
        auxiliary data. With ``include_auxdata=False`` only the main-yield
        row is returned (still with a leading axis of length one).
        """
        mu, gamma = pars
        expected_main = jnp.asarray([gamma * self.nominal + mu * self.sig])
        if not include_auxdata:
            # Previously this flag was silently ignored.
            return expected_main
        aux_data = jnp.asarray([self.aux])
        return jnp.concatenate([expected_main, aux_data])

    def logpdf(self, pars, data):
        """Return the log-likelihood of ``data`` given ``pars = (mu, gamma)``.

        ``data`` must unpack into ``(maindata, auxdata)``, i.e. the layout
        produced by `expected_data`. The result is an array with a single
        element so callers wanting a scalar must index it.
        """
        maindata, auxdata = data
        _, gamma = pars
        # Row 0 of the expected data holds the main (per-bin) rates.
        main_rates = self.expected_data(pars, include_auxdata=False)[0]
        main = pyhf.probability.Poisson(main_rates).log_prob(maindata)
        constraint = pyhf.probability.Poisson(gamma * self.factor).log_prob(auxdata)
        # sum log probs over bins
        return jnp.asarray([jnp.sum(main + constraint, axis=0)])
def hepdata_like(signal_data, bkg_data, bkg_uncerts, batch_size=None):
    """Dummy function to mimic the functionality of `pyhf.simplemodels.hepdata_like`.

    Bundles the signal yields, background yields and background
    uncertainties into a spec and returns the `Model` built from it.
    ``batch_size`` is accepted but not used here.
    """
    spec = [signal_data, bkg_data, bkg_uncerts]
    return Model(spec)
Let's build an example model, and get gradients of the likelihood function with respect to the model parameters:
import jax
import jax.numpy as jnp
import neos
from neos.models import hepdata_like
# example inputs for a model with three bins
sig = jnp.asarray([20, 40, 3])
bkg = jnp.asarray([40, 20, 3])
uncert = jnp.asarray([3, 3, 3])

# build the model, then take its expected data at (1.0, 1.0) as the "observed" data
m = hepdata_like(sig, bkg, uncert)
d = m.expected_data([1.0, 1.0])


# `logpdf` returns an array with one element, but differentiation
# needs a scalar output, so pull that single element out
def logpdf_scalar(pars):
    """Return the model log-likelihood of `d` at `pars` as a scalar."""
    log_likelihood = m.logpdf(pars, d)
    return log_likelihood[0]


# value and gradient of the log-likelihood with respect to the parameters
jax.value_and_grad(logpdf_scalar)([2.0, 1.0])
(DeviceArray(-27.74804929, dtype=float64), [DeviceArray(-22., dtype=float64), DeviceArray(-19., dtype=float64)])