One might be interested in optimizing for two "competing" models at the same time. Consider three separate samples A, B, C, where we'd like to extract the significance for two of the three simultaneously. Two models are fitted: e.g. one where A is signal and B & C are backgrounds, and one where B is signal and A & C are backgrounds. This example shows how to optimize for both of them at the same time.
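Concretely, the combined objective is just the average of the two per-model CLs values, so a single gradient step improves both analyses at once. As a sketch (the actual loss1 and loss2 are constructed below with neos and relaxed):

# combined objective at a fixed signal-strength test value mu:
# loss(params, mu) = (CLs_model1(params, mu) + CLs_model2(params, mu)) / 2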
import time
import jax
import jax.experimental.optimizers as optimizers
import jax.experimental.stax as stax
import jax.random
from jax.random import PRNGKey
import numpy as np
from functools import partial
from neos import data, makers
from relaxed import infer
rng = PRNGKey(22)
def hists_from_nn_three_samples(
predict,
NMC=500,
s1_mean=[-2, 2],
s2_mean=[2, 2],
s3_mean=[0, -2],
LUMI=10,
sig_scale=1,
bkg_scale=1,
group=1,
real_z=False,
):
'''
Same as hists_from_nn_three_blobs, but parametrizes the grouping of signal and
background for three separate samples.
Args:
    predict: Decision function for a parameterized observable. Assumed softmax here.
    group: Which sample (1, 2, or 3) plays the role of the signal; the other two are treated as backgrounds.
Returns:
hist_maker: A callable function that takes the parameters of the observable,
then constructs signal, background, and background uncertainty yields.
'''
def get_hists(network, s, bs):
NMC = len(s)
s_hist = predict(network, s).sum(axis=0) * sig_scale / NMC * LUMI
b_hists = tuple([
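            # NB: bs[0] is the other signal-like sample acting as a background here,
            # so it keeps sig_scale; with the defaults sig_scale == bkg_scale == 1
            # this makes no numerical difference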
predict(network, bs[0]).sum(axis=0) * sig_scale / NMC * LUMI,
predict(network, bs[1]).sum(axis=0) * bkg_scale / NMC * LUMI
])
b_tot = jax.numpy.sum(jax.numpy.asarray(b_hists), axis=0)
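        # Poisson-style uncertainty estimate on the total background yield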
b_unc = jax.numpy.sqrt(b_tot)
# append raw hists for signal and bkg as well
results = s_hist, b_tot, b_unc, s_hist, b_hists
return results
def hist_maker():
sig1 = np.random.multivariate_normal(s1_mean, [[1, 0], [0, 1]], size=(NMC,))
sig2 = np.random.multivariate_normal(s2_mean, [[1, 0], [0, 1]], size=(NMC,))
bkg = np.random.multivariate_normal(s3_mean, [[1, 0], [0, 1]], size=(NMC,))
def make(network):
if group == 1:
return get_hists(network, sig1, (sig2, bkg))
elif group == 2:
return get_hists(network, sig2, (sig1, bkg))
elif group == 3:
return get_hists(network, bkg, (sig1, sig2))
            else:
                raise ValueError(f"group must be 1, 2, or 3; got {group}")
make.bkg = bkg
make.sig2 = sig2
make.sig1 = sig1
return make
return hist_maker
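For orientation, a minimal usage sketch of the factory above; it assumes the predict function and network parameters constructed further down:

make = hists_from_nn_three_samples(predict, group=1)()
s, b_tot, b_unc, s_raw, b_raws = make(network)  # the five outputs of get_hists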
import pyhf
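# use the jax backend so that pyhf model evaluations stay differentiable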
pyhf.set_backend(pyhf.tensor.jax_backend())
from neos import models
def nn_hepdata_like_w_hists(histogram_maker):
'''
Analogous function to `makers.nn_hepdata_like`, but modified to pass through
the additional info added in hists_from_nn_three_samples.
'''
hm = histogram_maker()
def nn_model_maker(hpars):
network = hpars
        s, b, db, _, _ = hm(network)  # drop the raw hists; only s, b, db feed the model
m = models.hepdata_like(s, b, db) # neos model
nompars = m.config.suggested_init()
        bonlypars = jax.numpy.asarray(nompars)
bonlypars = jax.ops.index_update(bonlypars, m.config.poi_index, 0.0)
return m, bonlypars
nn_model_maker.hm = hm
return nn_model_maker
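The returned closure maps network parameters directly to a statistical model. A quick sketch of how it is consumed (with nnm and network as constructed below):

m, bonlypars = nnm(network)  # neos model plus its background-only (mu = 0) parameters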
NOUT = 3
init_random_params, predict = stax.serial(
stax.Dense(1024),
stax.Relu,
stax.Dense(1024),
stax.Relu,
stax.Dense(NOUT),
stax.Softmax,
)
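# The final softmax softly assigns each event to one of NOUT bins, which is
# what keeps the binned yields differentiable with respect to the network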
hmaker = hists_from_nn_three_samples(predict, group=1)
nnm = nn_hepdata_like_w_hists(hmaker)
hmaker2 = hists_from_nn_three_samples(predict, group=2)
nnm2 = nn_hepdata_like_w_hists(hmaker2)
loss1 = infer.make_hypotest(nnm, solver_kwargs=dict(pdf_transform=True))
loss2 = infer.make_hypotest(nnm2, solver_kwargs=dict(pdf_transform=True))
# optimize the average CLs of the two models (lower CLs means higher significance)!
loss = lambda params, test_mu: (loss1(params, test_mu)['CLs'] + loss2(params, test_mu)['CLs'])/2
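If the two analyses are not equally important, the same machinery admits a weighted combination. An illustrative variant (the weights w1 and w2 are hypothetical, not used in the run below):

def weighted_loss(params, test_mu, w1=0.7, w2=0.3):
    # weighted mix of the two CLs objectives; w1 + w2 = 1 keeps the scale comparable
    return w1 * loss1(params, test_mu)['CLs'] + w2 * loss2(params, test_mu)['CLs']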
_, network = init_random_params(jax.random.PRNGKey(2), (-1, 2))
loss(network, 1.)
DeviceArray(0.05143708, dtype=float32)
nnm.hm(network)
(DeviceArray([3.184768 , 3.3555813, 3.4596505], dtype=float32), DeviceArray([6.6250687, 6.9932375, 6.381694 ], dtype=float32), DeviceArray([2.573921 , 2.644473 , 2.5262015], dtype=float32), DeviceArray([3.184768 , 3.3555813, 3.4596505], dtype=float32), (DeviceArray([3.1510186, 3.7200375, 3.1289444], dtype=float32), DeviceArray([3.47405 , 3.2732 , 3.2527497], dtype=float32)))
a, b, c, d, e = nnm.hm(network)
#jit_loss = jax.jit(loss)
opt_init, opt_update, opt_params = optimizers.adam(.5e-3)
def train_network(N, cont=False, network=None):
    # start from freshly initialized parameters unless we are continuing
    # from an explicitly passed-in network
    if not cont or network is None:
        _, network = init_random_params(jax.random.PRNGKey(4), (-1, 2))
losses = []
cls_vals = []
state = opt_init(network)
# parameter update function
#@jax.jit
    def update_and_value(i, opt_state, mu, loss_choice):
        # one optimization step: evaluate the loss and its gradient at the
        # current parameters, then apply the optimizer update
        net = opt_params(opt_state)
        value, grad = jax.value_and_grad(loss_choice)(net, mu)
        return opt_update(i, grad, opt_state), value, net
for i in range(N):
start_time = time.time()
loss_choice = loss
        state, value, network = update_and_value(i, state, 1.0, loss_choice)
epoch_time = time.time() - start_time
losses.append(value)
metrics = {"loss": losses}
yield network, metrics, epoch_time
# Choose colormap
import matplotlib.pylab as pl
from matplotlib.colors import ListedColormap
def to_transp(cmap):
    # return a copy of cmap whose alpha channel ramps from fully transparent
    # to 0.7, so the three overlaid contour sets remain visible
    my_cmap = cmap(np.arange(cmap.N))
    my_cmap[:, -1] = np.linspace(0, 0.7, cmap.N)
    return ListedColormap(my_cmap)
def plot(axarr, network, metrics, hm, hm2, maxN, ith):
xlim = (-5, 5)
ylim = (-5, 5)
g = np.mgrid[xlim[0]:xlim[1]:101j, ylim[0]:ylim[1]:101j]
levels = np.linspace(0, 1, 20)
ax = axarr[0]
ax.contourf(
g[0],
g[1],
predict(network, np.moveaxis(g, 0, -1)).reshape(101, 101, NOUT)[:, :, 0],
levels=levels,
cmap = to_transp(pl.cm.Reds),
)
ax.contourf(
g[0],
g[1],
predict(network, np.moveaxis(g, 0, -1)).reshape(101, 101, NOUT)[:, :, 1],
levels=levels,
cmap = to_transp(pl.cm.Greens),
)
if NOUT > 2:
ax.contourf(
g[0],
g[1],
            predict(network, np.moveaxis(g, 0, -1)).reshape(101, 101, NOUT)[:, :, 2],
levels=levels,
cmap = to_transp(pl.cm.Blues),
)
#print(list(map(len, [hm.sig1[:, 0], hm.sig2[:, 0], hm.bkg[:, 0]])))
ax.scatter(hm.sig1[:, 0], hm.sig1[:, 1], alpha=0.25, c="C9", label="sig1")
    ax.scatter(hm.sig2[:, 0], hm.sig2[:, 1], alpha=0.17, c="C8", label="sig2")
    ax.scatter(hm.bkg[:, 0], hm.bkg[:, 1], alpha=0.17, c="C1", label="bkg")
ax.set_xlim(*xlim)
ax.set_ylim(*ylim)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax = axarr[1]
ax.axhline(0.05, c="slategray", linestyle="--")
ax.plot(metrics["loss"][:ith], c="steelblue", linewidth=2.0)
ax.set_ylim(0, metrics["loss"][0])
ax.set_xlim(0, maxN)
ax.set_xlabel("epoch")
    ax.set_ylabel(r"$CL_s$")
ax = axarr[2]
s, b, db, sig, bs = hm(network)
ytop = np.max(np.sum([s, b], axis=0))*1.3
ax.bar(range(NOUT), sig, bottom=bs[0]+bs[1], color="C9", label="Sample 1")
ax.bar(range(NOUT), bs[0], bottom=bs[1], color="C8", label="Sample 2")
ax.bar(range(NOUT), bs[1], color="C1", label="Sample 3")
ax.set_ylabel("frequency")
ax.set_xlabel("nn output")
ax.set_title("Raw histograms")
ax.set_ylim(0, ytop)
if ith == 0:
ax.legend()
ax = axarr[3]
s, b, db, sig, bs = hm(network)
ax.bar(range(NOUT), s, bottom=b, color="#722620", label="sig", alpha=0.9)
ax.bar(range(NOUT), b, color="#F2BC94", label="bkg")
ax.bar(range(NOUT), db, bottom=b - db / 2.0, alpha=0.3, color="black", label="bkg error", hatch='////')
ax.set_ylabel("frequency")
ax.set_xlabel("nn output")
ax.set_title("Model 1: sig1 vs (sig2 + bkg)")
ax.set_ylim(0, ytop)
if ith == 0:
ax.legend()
ax = axarr[4]
s, b, db, sig, bs = hm2(network)
ax.bar(range(NOUT), s, bottom=b, color="#722620", label="sig")
ax.bar(range(NOUT), b, color="#F2BC94", label="bkg")
ax.bar(range(NOUT), db, bottom=b - db / 2.0, alpha=0.3, color="black", label="bkg error", hatch='////')
ax.set_ylabel("frequency")
ax.set_xlabel("nn output")
ax.set_title("Model 2: sig2 vs (sig1 + bkg)")
ax.set_ylim(0, ytop)
if ith == 0:
ax.legend()
import numpy as np
from matplotlib import pyplot as plt
from celluloid import Camera
from IPython.display import HTML
plt.rcParams.update(
{
"axes.labelsize": 13,
"axes.linewidth": 1.2,
"xtick.labelsize": 13,
"ytick.labelsize": 13,
"figure.figsize": [12.0, 8.0],
"font.size": 13,
"xtick.major.size": 3,
"ytick.major.size": 3,
"legend.fontsize": 11,
}
)
fig, axarr = plt.subplots(2, 3, dpi=120)
axarr = axarr.flatten()
# fig.set_size_inches(15, 10)
camera = Camera(fig)
maxN = 20 # make me bigger for better results!
animate = True # animations fail tests
# Training
for i, (network, metrics, epoch_time) in enumerate(train_network(maxN)):
print(f"epoch {i}:", f'CLs = {metrics["loss"][-1]}, took {epoch_time}s')
if animate:
plot(axarr, network, metrics, nnm.hm, nnm2.hm, maxN=maxN, ith=i)
plt.tight_layout()
camera.snap()
if animate:
camera.animate().save("animation.gif", writer="imagemagick", fps=10)
epoch 0: CLs = 0.05124938488006592, took 7.84034276008606s
epoch 1: CLs = 0.038916826248168945, took 0.24336791038513184s
epoch 2: CLs = 0.02349996566772461, took 0.23299694061279297s
epoch 3: CLs = 0.011642754077911377, took 0.24342584609985352s
epoch 4: CLs = 0.005460619926452637, took 0.2316129207611084s
epoch 5: CLs = 0.002736389636993408, took 0.23945212364196777s
epoch 6: CLs = 0.0015284419059753418, took 0.292529821395874s
epoch 7: CLs = 0.0009472370147705078, took 0.22985506057739258s
epoch 8: CLs = 0.000640869140625, took 0.23697280883789062s
epoch 9: CLs = 0.0004658699035644531, took 0.2311398983001709s
epoch 10: CLs = 0.00035893917083740234, took 0.2465968132019043s
epoch 11: CLs = 0.00028961896896362305, took 0.23459291458129883s
epoch 12: CLs = 0.0002422928810119629, took 0.23319482803344727s
epoch 13: CLs = 0.00020873546600341797, took 0.2266685962677002s
epoch 14: CLs = 0.00018399953842163086, took 0.24457406997680664s
epoch 15: CLs = 0.0001652836799621582, took 0.22907018661499023s
epoch 16: CLs = 0.0001507401466369629, took 0.3017148971557617s
epoch 17: CLs = 0.0001392960548400879, took 0.22998595237731934s
epoch 18: CLs = 0.00013005733489990234, took 0.2310020923614502s
epoch 19: CLs = 0.00012248754501342773, took 0.22712182998657227s