# default_exp transforms
Contains transforms to map from $[-\infty,\infty]$ to a bounded space $[a,b]$ and back.
This module implements two transforms, adopted from the MINUIT optimizer:
$$P_{\mathrm{inf}}=\arcsin \left(2 \frac{P_{\mathrm{bounded}}-a}{b-a}-1\right):~[a,b] \rightarrow [-\infty,\infty]$$
$$P_{\mathrm{bounded}}=a+\frac{b-a}{2}\left(\sin P_{\mathrm{inf}}+1\right):~[-\infty,\infty]\rightarrow [a,b]$$
Strictly, $\arcsin$ returns values in $[-\pi/2,\pi/2]$. Because $\sin$ is periodic, the second transform maps every point of the real line back into $[a,b]$, so the minimizer can move freely. These transforms add stability to the maximum likelihood fits of the model parameters, which are currently done by gradient descent. The minimization runs over the whole real line, and each trial value is mapped into a 'sensible' interval $[a,b]$ before the likelihood is evaluated. Without this mapping, the likelihood could be evaluated at negative or very extreme parameter values, potentially causing numerical instability in the likelihood or gradient evaluations.
# export
import jax
import jax.numpy as jnp
# Switch JAX to 64-bit (double precision) arrays; JAX defaults to float32.
# This is a global JAX setting, presumably enabled here so that the allclose
# checks on likelihoods and gradients below are not spoiled by float32 round-off.
jax.config.update("jax_enable_x64", True)
# export
# [-inf, inf] -> [a,b] (vectors)
def to_bounded_vec(param, bounds):
    """Map unbounded values onto per-parameter intervals ``[a_i, b_i]``.

    Vectorised form of the bounded transform: element ``i`` of ``param`` is
    sent to ``a_i + (b_i - a_i) / 2 * (sin(param_i) + 1)``.

    Args:
        param: array of values on the real line, one per row of ``bounds``.
        bounds: array-like of shape ``(n, 2)``; row ``i`` holds ``(a_i, b_i)``.

    Returns:
        ``jnp`` array whose element ``i`` lies in ``[a_i, b_i]``.
    """
    limits = jnp.asarray(bounds)
    lower = limits[:, 0]
    upper = limits[:, 1]
    half_span = (upper - lower) * 0.5
    return lower + half_span * (jnp.sin(param) + 1.0)
# [-inf, inf] -> [a,b]
def to_bounded(param, bounds):
    """Map an unbounded value onto the interval ``[a, b]``.

    Implements ``P_bounded = a + (b - a) / 2 * (sin(P_inf) + 1)``, so every
    real input (scalar or array) lands inside ``[a, b]``.

    Args:
        param: value(s) on the real line.
        bounds: pair ``(a, b)`` with the lower and upper limit.

    Returns:
        The transformed value(s) in ``[a, b]``.
    """
    lower, upper = bounds
    width = upper - lower
    return lower + width * 0.5 * (jnp.sin(param) + 1.0)
# [-inf, inf] <- [a,b] (vectors)
def to_inf_vec(param, bounds):
    """Map values in per-parameter intervals ``[a_i, b_i]`` onto inf space.

    Vectorised inverse of the bounded transform: element ``i`` becomes
    ``arcsin(2 * (param_i - a_i) / (b_i - a_i) - 1)``, a value in
    ``[-pi/2, pi/2]``.

    Args:
        param: array of values; element ``i`` is expected to lie in
            ``[a_i, b_i]``.
        bounds: array-like of shape ``(n, 2)``; row ``i`` holds ``(a_i, b_i)``.

    Returns:
        ``jnp`` array of inf-space values. Inputs outside ``[a_i, b_i]`` give
        an ``arcsin`` argument outside ``[-1, 1]`` and therefore NaN.
    """
    bounds = jnp.asarray(bounds)
    a, b = bounds[:, 0], bounds[:, 1]
    # Rescale [a, b] -> [-1, 1]. The lower bound has to be subtracted before
    # doubling; ``(2 * param - a)`` would only be correct for ``a == 0``.
    x = 2.0 * (param - a) / (b - a) - 1.0
    return jnp.arcsin(x)
# [-inf, inf] <- [a,b]
def to_inf(param, bounds):
    """Map value(s) in ``[a, b]`` onto inf space.

    Implements ``P_inf = arcsin(2 * (P_bounded - a) / (b - a) - 1)``, giving
    values in ``[-pi/2, pi/2]``.

    Args:
        param: value(s) expected to lie in ``[a, b]``.
        bounds: pair ``(a, b)`` with the lower and upper limit.

    Returns:
        The inf-space value(s). Inputs outside ``[a, b]`` give an ``arcsin``
        argument outside ``[-1, 1]`` and therefore NaN.
    """
    a, b = bounds
    # Rescale [a, b] -> [-1, 1]. The lower bound has to be subtracted before
    # doubling; ``(2 * param - a)`` would only be correct for ``a == 0``.
    x = 2.0 * (param - a) / (b - a) - 1.0
    return jnp.arcsin(x)
import numpy as np
# Sample points inside [-pi/2, pi/2] (where arcsin(sin(t)) == t), plus the
# [a, b] bounds that go with each of them.
p = jnp.asarray([1.0, 1.0])
b = jnp.asarray([[0.0, 10.0], [0.0, 10.0]])

# Scalar round trip: inf space -> [a, b] -> inf space should give back p[0].
scalar_roundtrip = to_inf(to_bounded(p[0], b[0]), b[0])
cond = np.allclose(scalar_roundtrip, p[0])
assert cond, f"{scalar_roundtrip} != {p[0]}"

# Vector round trip over all points at once.
vector_roundtrip = to_inf_vec(to_bounded_vec(p, b), b)
cond = np.allclose(vector_roundtrip, p)
assert cond, f"{vector_roundtrip} != {p}"
# hide
bounds = jnp.array([[0, 10], [0, 20]])
# Check that [a, b] maps onto inf space, i.e. onto [-pi/2, pi/2] (the range of arcsin).
w = jnp.linspace(0, 10)
x = to_inf(w, bounds[0])
print(f"min: {w.min()}, max: {w.max()}, to inf:")
print(x.min(), x.max())
# The endpoints a and b must land on -pi/2 and pi/2 respectively.
assert np.allclose(
    np.asarray([x.min(), x.max()]), [-np.pi / 2, np.pi / 2]
), "Bounded values are not mapped onto [-pi/2, pi/2] by the inf transform"
# Check that very large values are still mapped into the bounded space: with
# 1001 samples over +-1e10 the sine values come close to -1 and +1, so the
# extremes should approach a and b.
w = jnp.linspace(-1e10, 1e10, 1001)
x = to_bounded(w, bounds[0])
print(f"min: {w.min()}, max: {w.max()}, to {bounds[0]}:")
print(x.min(), x.max())
assert np.allclose(
    np.asarray([x.min(), x.max()]), bounds[0], atol=1e-5
), "Large numbers are not mapped to the bounds of the bounded transform"
min: 0.0, max: 10.0, to inf: -1.5707963267948966 1.5707963267948966 min: -10000000000.0, max: 10000000000.0, to [ 0 10]: 2.5122629992990753e-06 9.999997487737
# hide
# define NLL functions in both parameter spaces
from neos import models
def make_nll_boundspace(hyperpars):
    """Build the negative log-likelihood as a function of bounded-space parameters.

    Args:
        hyperpars: ``(s, b, db)`` forwarded to ``models.hepdata_like`` as
            one-element arrays (presumably signal yield, background yield and
            background uncertainty, as in pyhf's ``hepdata_like`` -- verify).

    Returns:
        ``nll_boundspace(pars)``. It evaluates the model's ``logpdf`` at
        ``pars`` against the model's own expected data at parameters
        ``[0, 1]``, and returns the negated first entry.
    """
    sig, bkg, bkg_unc = hyperpars

    def nll_boundspace(pars):
        # The model is rebuilt on every call from the captured hyperparameters.
        model = models.hepdata_like(
            jnp.asarray([sig]), jnp.asarray([bkg]), jnp.asarray([bkg_unc])
        )
        # The "data" are the model's expected yields at the fixed point [0, 1].
        reference_data = model.expected_data([0, 1])
        log_density = model.logpdf(pars, reference_data)
        return -log_density[0]

    return nll_boundspace
def make_nll_infspace(hyperpars, par_bounds=None):
    """Build the negative log-likelihood as a function of inf-space parameters.

    The returned closure first maps its input from [-inf, inf] into bounded
    space with ``to_bounded_vec``. It then evaluates the same likelihood as
    ``make_nll_boundspace``.

    Args:
        hyperpars: ``(s, b, db)`` forwarded to ``models.hepdata_like`` as
            one-element arrays (presumably signal yield, background yield and
            background uncertainty, as in pyhf's ``hepdata_like`` -- verify).
        par_bounds: optional per-parameter ``[a, b]`` bounds for the transform.
            When ``None`` (the default), the module-level ``bounds`` is looked
            up each time the NLL is evaluated. Rebinding that global therefore
            also changes closures created earlier; pass ``par_bounds`` to pin
            the transform.

    Returns:
        ``nll_infspace(pars)``, returning the negated first entry of the
        model's ``logpdf`` at the transformed ``pars``.
    """
    s, b, db = hyperpars

    def nll_infspace(pars):
        truth_pars = [0, 1]
        # Late lookup of the global keeps the historical behaviour by default.
        active_bounds = bounds if par_bounds is None else par_bounds
        pars = to_bounded_vec(pars, active_bounds)
        m = models.hepdata_like(jnp.asarray([s]), jnp.asarray([b]), jnp.asarray([db]))
        val = m.logpdf(pars, m.expected_data(truth_pars))
        return -val[0]

    return nll_infspace
nll_boundspace = make_nll_boundspace([1, 50, 7])
nll_infspace = make_nll_infspace([1, 50, 7])
# Pick a point in bounded space and its image in inf space.
apoint_bnd = jnp.array([0.5, 0.5])
apoint_inf = to_inf_vec(apoint_bnd, bounds)
# The NLL must agree in both spaces: nll_infspace maps back to bounded space
# before evaluating the same likelihood.
print("check consistency in both spaces:")
point_bound = nll_boundspace(apoint_bnd)
point_inf = nll_infspace(apoint_inf)
assert np.allclose(
    point_bound, point_inf
), f"{point_bound} (bounded) should be close to {point_inf} (inf)"
print("..good!")
# Gradient with respect to the bounded-space parameters.
print("gradients in bounded space:")
dlb_dpb = jax.grad(nll_boundspace)(apoint_bnd)
print(dlb_dpb)
# Gradient with respect to the inf-space parameters.
print("gradients in inf space:")
dli_dinf = jax.grad(nll_infspace)(apoint_inf)
print(dli_dinf)
# Chain rule: dL/dp_bnd = dL/dp_inf * dp_inf/dp_bnd. to_inf_vec is elementwise,
# so its Jacobian is diagonal and the gradient of its sum is exactly that
# diagonal. This takes one backward pass for any number of parameters.
print("consistency? check with chain rule:")
dpinf_dpbnd = jax.grad(lambda x: jnp.sum(to_inf_vec(x, bounds)))(apoint_bnd)
dli_dpi = dli_dinf * dpinf_dpbnd
print(dli_dpi)
# Differentiating through inf -> bounded and back must reproduce the
# bounded-space gradient.
cond = np.allclose(dli_dpi, dlb_dpb)
assert cond, "Chain rule... doesnt work? :o"
print("all good here chief")
check consistency in both spaces: ..good! gradients in bounded space: [ -0.96078431 -99.05962385] gradients in inf space: [ -2.09398087 -309.31357633] consistency? check with chain rule: [ -0.96078431 -99.05962385] all good here chief
# hide
import scipy
import pyhf
from neos import cls, fit
# Select JAX as pyhf's tensor backend; this is global pyhf state that applies
# to every later pyhf call in this notebook (e.g. the pyhf cross-checks).
jax_tensor_backend = pyhf.tensor.jax_backend()
pyhf.set_backend(jax_tensor_backend)
def fit_nll_bounded(init, hyperpars):
    """Fit the bounded-space NLL with ``scipy.optimize.minimize``.

    Args:
        init: starting parameter vector in bounded space.
        hyperpars: ``[mu, s, b, db]``. ``mu`` is read but unused; ``[s, b, db]``
            build the objective via ``make_nll_boundspace``.

    Returns:
        The optimiser's solution vector ``x``. Box constraints come from the
        module-level ``bounds`` at call time.
    """
    _mu = hyperpars[0]  # unused by this fit
    objective = make_nll_boundspace(hyperpars[1:])
    solution = scipy.optimize.minimize(objective, x0=init, bounds=bounds)
    return solution.x
def fit_nll_infspace(init, hyperpars):
    """Fit the NLL in inf space with scipy, then map the optimum to bounded space.

    Args:
        init: starting parameter vector in inf space.
        hyperpars: ``[mu, s, b, db]``. ``mu`` is read but unused; ``[s, b, db]``
            build the objective via ``make_nll_infspace``.

    Returns:
        The best-fit parameters mapped through ``to_bounded_vec`` with the
        module-level ``bounds``.
    """
    mu, model_pars = hyperpars[0], hyperpars[1:]  # mu is unused by this fit
    objective = make_nll_infspace(model_pars)
    # Unconstrained minimisation is fine here: the objective itself maps its
    # input into bounded space. Take ``.x`` so that to_bounded_vec receives the
    # parameter vector rather than the OptimizeResult. (``funnyscipy`` was an
    # undefined name.)
    result = scipy.optimize.minimize(objective, x0=init).x
    return to_bounded_vec(result, bounds)
# Scipy fits in bounded space over a range of background uncertainties (disabled).
if False:
    print("scipy minim in bounded space")
    for db_val in [7, 2, 1, 0.1, 0.01]:
        print(fit_nll_bounded(apoint_bnd, [1.0, 5, 50, db_val]))
# Scipy fits in inf space over a range of background uncertainties (disabled).
if False:
    print("scipy minim in inf space")
    for db_val in [7, 2, 1, 0.1, 0.01, 0.001]:
        print(fit_nll_infspace(apoint_inf, [1.0, 5, 50, db_val]))
def nn_model_maker(nn_params):
    """Build a one-bin ``hepdata_like`` model and its fit starting point.

    Args:
        nn_params: ``(s, b, db)`` forwarded to ``models.hepdata_like`` as
            one-element arrays.

    Returns:
        ``(m, bonlypars)``: the model, and its ``suggested_init()`` parameters
        with the parameter of interest (``m.config.poi_index``) set to 0.0.
        The name suggests the background-only hypothesis.
    """
    s, b, db = nn_params
    m = models.hepdata_like(jnp.asarray([s]), jnp.asarray([b]), jnp.asarray([db]))
    nompars = m.config.suggested_init()
    # Functional, differentiable index update. ``jax.ops.index_update`` was
    # deprecated and later removed from JAX; ``.at[...].set(...)`` is the
    # supported equivalent.
    bonlypars = jnp.asarray(nompars).at[m.config.poi_index].set(0.0)
    return m, bonlypars
# Differentiable solvers from neos, built around nn_model_maker. With
# pdf_transform=True they presumably optimise in the transformed (inf) space,
# which is why their results are passed through to_bounded_vec -- verify in neos.fit.
g_fitter, c_fitter = fit.get_solvers(
    nn_model_maker,
    learning_rate=1e-4,
    pdf_transform=True,
)
# Rebind the module-level bounds: [0, 10] for both parameters from here on.
bounds = jnp.array([[0.0, 10.0], [0.0, 10.0]])
# Differentiable fits in inf space via g_fitter (disabled).
if False:
    print("diffable minim in inf space")
    apoint_bnd = jnp.array([0.5, 0.5])
    apoint_inf = to_inf_vec(apoint_bnd, bounds)
    for nn_pars in (
        [5, 50, 7.0],
        [5, 50, 2.0],
        [5, 50, 1.0],
        [5, 50, 0.5],
        [5, 50, 0.1],
        [5, 50, 0.01],
        [5, 55, 1.5],
        [10, 5, 1.5],
        [2, 90, 1.5],
    ):
        print(to_bounded_vec(g_fitter(apoint_inf, [1.0, nn_pars]), bounds))

# Value and gradient (w.r.t. the model hyperparameters) of bounded-space
# parameter 0 after the global fit. The lambda reads the globals apoint_inf
# and bounds at call time.
print("global fit grad")
global_fit_par0 = jax.value_and_grad(
    lambda x: to_bounded_vec(g_fitter(apoint_inf, [1.0, x]), bounds)[0]
)
for nn_pars in (
    [5.0, 50.0, 15.0],
    [5.0, 50.0, 10.0],
    [5.0, 50.0, 7.0],
    [5.0, 50.0, 1.0],
):
    print(global_fit_par0(nn_pars))

# Constrained fits from a new starting point [1, 1] in bounded space.
print("constrained!")
apoint_bnd = jnp.array([1.0, 1.0])
apoint_inf = to_inf_vec(apoint_bnd, bounds)
for nn_pars in (
    [5, 50, 15.0],
    [5, 50, 10.0],
    [5, 50, 7.0],
    [5, 50, 1.0],
    [5, 50, 0.1],
):
    print(to_bounded_vec(c_fitter(apoint_inf, [1.0, nn_pars]), bounds))

# Value and gradient of bounded-space parameter 1 after the constrained fit.
print("constrained fit grad")
constrained_fit_par1 = jax.value_and_grad(
    lambda x: to_bounded_vec(c_fitter(apoint_inf, [1.0, x]), bounds)[1]
)
for nn_pars in (
    [5.0, 50.0, 15.0],
    [5.0, 50.0, 10.0],
    [5.0, 50.0, 7.0],
    [5.0, 50.0, 1.0],
    [5.0, 50.0, 0.1],
):
    print(constrained_fit_par1(nn_pars))
def fit_nll_bounded_constrained(init, hyperpars, fixed_val):
    """Reference scipy fit in bounded space with parameter 0 pinned.

    Args:
        init: starting parameter vector in bounded space.
        hyperpars: ``[mu, s, b, db]``. ``mu`` is read but unused; ``[s, b, db]``
            build the objective via ``make_nll_boundspace``.
        fixed_val: value at which parameter 0 is held by an equality constraint.

    Returns:
        The optimiser's solution vector ``x``. Box constraints come from the
        module-level ``bounds`` at call time.
    """
    _mu = hyperpars[0]  # unused by this fit
    objective = make_nll_boundspace(hyperpars[1:])

    def pin_first_param(v):
        # Zero exactly when v[0] == fixed_val, as scipy "eq" constraints require.
        return v[0] - fixed_val

    solution = scipy.optimize.minimize(
        objective,
        x0=init,
        bounds=bounds,
        constraints=[{"type": "eq", "fun": pin_first_param}],
    )
    return solution.x
# scipy reference for the constrained fits above: parameter 0 pinned at 1.0.
print("reference")
for db_val in (15.0, 10.0, 7.0, 1.0, 0.1):
    print(fit_nll_bounded_constrained(apoint_bnd, [1.0, 5, 50, db_val], 1.0))
# CLs from neos' differentiable pipeline at mu = 1.0. value_and_grad also runs
# the backward pass, but only the value (index 0) is kept for the comparison.
print("diffable cls")
j_cls = [
    jax.value_and_grad(
        cls.cls_maker(nn_model_maker, solver_kwargs=dict(pdf_transform=True))
    )(nn_pars, 1.0)[0]
    for nn_pars in (
        [5.0, 50.0, 15.0],
        [5.0, 50.0, 10.0],
        [5.0, 50.0, 7.0],
        [5.0, 50.0, 1.0],
        [5.0, 50.0, 0.1],
        [10.0, 5.0, 0.1],
        [15.0, 5.0, 0.1],
    )
]
print("cross check cls")
def pyhf_cls(nn_params, mu):
    """Reference CLs from pyhf for a one-bin ``hepdata_like`` model.

    Args:
        nn_params: ``(s, b, db)`` passed to ``pyhf.simplemodels.hepdata_like``
            as one-element lists.
        mu: signal-strength hypothesis to test.

    Returns:
        Element 0 of ``pyhf.infer.hypotest``'s result. This is presumably the
        observed CLs, wrapped in a length-1 tensor by older pyhf versions,
        hence the ``[0]`` -- confirm for the pinned pyhf version.
    """
    s, b, db = nn_params
    m = pyhf.simplemodels.hepdata_like([s], [b], [db])
    # Observed main-channel count is set to the background yield b, followed
    # by the model's auxiliary data. Test the requested ``mu`` instead of a
    # hard-coded 1.0.
    return pyhf.infer.hypotest(mu, [b] + m.config.auxdata, m)[0]
# Reference CLs values from pyhf for the same hyperparameters at mu = 1.0;
# they must agree with the differentiable values collected in j_cls.
p_cls = [
    pyhf_cls(nn_pars, 1.0)
    for nn_pars in (
        [5.0, 50.0, 15.0],
        [5.0, 50.0, 10.0],
        [5.0, 50.0, 7.0],
        [5.0, 50.0, 1.0],
        [5.0, 50.0, 0.1],
        [10.0, 5.0, 0.1],
        [15.0, 5.0, 0.1],
    )
]
assert np.allclose(np.asarray(j_cls), np.asarray(p_cls)), "cls values don't match pyhf"
global fit grad (DeviceArray(4.04676292e-13, dtype=float64), [DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64)]) (DeviceArray(1.71529457e-13, dtype=float64), [DeviceArray(-6.85560674e-14, dtype=float64), DeviceArray(9.1365743e-15, dtype=float64), DeviceArray(-2.28414357e-14, dtype=float64)]) (DeviceArray(9.93649607e-14, dtype=float64), [DeviceArray(-3.98545918e-14, dtype=float64), DeviceArray(5.99354644e-15, dtype=float64), DeviceArray(-2.87322456e-14, dtype=float64)]) (DeviceArray(6.38378239e-14, dtype=float64), [DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64)]) constrained! [1. 0.91976327] [1. 0.93563135] [1. 0.95299143] [1. 0.99821931] [1. 0.99998647] constrained fit grad (DeviceArray(0.91976327, dtype=float64), [DeviceArray(-0.01570885, dtype=float64), DeviceArray(0.00188766, dtype=float64), DeviceArray(-0.00211123, dtype=float64)]) (DeviceArray(0.93563135, dtype=float64), [DeviceArray(-0.01240237, dtype=float64), DeviceArray(0.00169771, dtype=float64), DeviceArray(-0.00457567, dtype=float64)]) (DeviceArray(0.95299143, dtype=float64), [DeviceArray(-0.00890556, dtype=float64), DeviceArray(0.00138767, dtype=float64), DeviceArray(-0.00710027, dtype=float64)]) (DeviceArray(0.99821931, dtype=float64), [DeviceArray(-0.0003251, dtype=float64), DeviceArray(6.73871338e-05, dtype=float64), DeviceArray(-0.00349722, dtype=float64)]) (DeviceArray(0.99998647, dtype=float64), [DeviceArray(-2.78601582e-06, dtype=float64), DeviceArray(4.28295151e-07, dtype=float64), DeviceArray(-0.00022806, dtype=float64)]) reference [1. 0.91979939] [1. 0.93570921] [1. 0.95295097] [1. 0.9982233] [1. 0.99998182] diffable cls cross check cls