"""Train a neural network whose loss is the CLs statistic itself (via neos /
relaxed / pyhf), and optionally animate the decision boundary during training.

Reformatted from a whitespace-mangled source; logic preserved except where a
FIX comment notes otherwise.
"""
import time
from functools import partial

import jax
# NOTE(review): jax.experimental.optimizers / stax moved to
# jax.example_libraries in newer JAX releases -- confirm the pinned JAX
# version before modernizing these imports.
import jax.experimental.optimizers as optimizers
import jax.experimental.stax as stax
import jax.random
import numpy as np
from jax.random import PRNGKey

import pyhf

# Run pyhf on the JAX backend so the hypothesis test is differentiable.
pyhf.set_backend("jax")
pyhf.default_backend = pyhf.tensor.jax_backend(precision="64b")

from neos import data, makers
from relaxed import infer

rng = PRNGKey(22)

# Two-hidden-layer MLP with a 2-way softmax head; the two outputs act as
# per-event weights for the two histogram bins downstream.
init_random_params, predict = stax.serial(
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(2),
    stax.Softmax,
)

# Toy dataset: three 2-D blobs (signal + two backgrounds).
dgen = data.generate_blobs(rng, blobs=3)
# hmaker(network) -> (signal, background, background-uncertainty) yields.
hmaker = makers.hists_from_nn(dgen, predict, method="softmax", hpar_dict=None)
# Differentiable hepdata-like statistical model built from those histograms.
nnm = makers.hepdata_like_from_hists(hmaker)
get_cls = infer.make_hypotest(nnm, solver_kwargs=dict(pdf_transform=True))


def loss(params, test_mu):
    """Return the CLs value at signal strength ``test_mu`` for nn weights ``params``.

    ``get_cls`` returns a dict of metrics -- we index into the "CLs" entry.
    """
    return get_cls(params, test_mu)["CLs"]


_, network = init_random_params(jax.random.PRNGKey(2), (-1, 2))

# Smoke test: gradient of CLs w.r.t. the nn weights (result intentionally
# discarded -- this just checks differentiability end-to-end).
jax.value_and_grad(loss)(network, test_mu=1.0)

opt_init, opt_update, opt_params = optimizers.adam(1e-3)


def train_network(N):
    """Generator yielding ``(network, metrics, epoch_time)`` for ``N`` Adam steps.

    ``metrics`` is ``{"loss": [CLs per epoch so far]}``; ``epoch_time`` is the
    wall-clock seconds spent on the step.
    """
    _, network = init_random_params(jax.random.PRNGKey(1), (-1, 2))
    state = opt_init(network)
    losses = []

    # parameter update function
    # @jax.jit
    def update_and_value(i, opt_state, mu):
        net = opt_params(opt_state)
        value, grad = jax.value_and_grad(loss)(net, mu)
        # FIX: use the ``opt_state`` argument rather than closing over the
        # enclosing ``state`` variable (as the original did). Identical here
        # because the caller always passes ``state``, but the closure was
        # error-prone and would break the commented-out @jax.jit.
        return opt_update(i, grad, opt_state), value, net

    for i in range(N):
        start_time = time.time()
        state, value, network = update_and_value(i, state, 1.0)
        epoch_time = time.time() - start_time
        losses.append(value)
        metrics = {"loss": losses}
        yield network, metrics, epoch_time


def plot(axarr, network, metrics, maxN):
    """Draw, on the three axes in ``axarr``: the nn decision contours over the
    data blobs, the CLs-vs-epoch curve, and the current signal/background
    histograms."""
    # Panel 1: network output contours over a [-5, 5]^2 grid, plus the blobs.
    ax = axarr[0]
    g = np.mgrid[-5:5:101j, -5:5:101j]
    levels = np.linspace(0, 1, 21)
    ax.contourf(
        g[0],
        g[1],
        predict(network, np.moveaxis(g, 0, -1)).reshape(101, 101, 2)[:, :, 0],
        levels=levels,
        cmap="BrBG",
    )
    ax.contour(
        g[0],
        g[1],
        predict(network, np.moveaxis(g, 0, -1)).reshape(101, 101, 2)[:, :, 0],
        colors="w",
        levels=levels,
    )
    sig, bkg1, bkg2 = dgen()
    ax.scatter(sig[:, 0], sig[:, 1], alpha=0.25, c="C9", label="sig")
    ax.scatter(bkg1[:, 0], bkg1[:, 1], alpha=0.17, c="C1", label="bkg1")
    # NOTE(review): bkg2 reuses color "C1" -- presumably intentional (both are
    # backgrounds); confirm if the two should be visually distinguishable.
    ax.scatter(bkg2[:, 0], bkg2[:, 1], alpha=0.17, c="C1", label="bkg2")
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    # Panel 2: loss (CLs) history, with the conventional 0.05 exclusion line.
    ax = axarr[1]
    ax.axhline(0.05, c="slategray", linestyle="--")
    ax.plot(metrics["loss"], c="steelblue", linewidth=2.0)
    ax.set_ylim(0, 0.6)
    ax.set_xlim(0, maxN)
    ax.set_xlabel("epoch")
    ax.set_ylabel(r"$cl_s$")

    # Panel 3: stacked signal/background yields with background error band.
    ax = axarr[2]
    s, b, db = hmaker(network)
    ax.bar([0, 1], b, color="C1", label="bkg")
    ax.bar([0, 1], s, bottom=b, color="C9", label="sig")
    ax.bar([0, 1], db, bottom=b - db / 2.0, alpha=0.4, color="black", label="bkg error")
    ax.set_ylim(0, 100)
    ax.set_ylabel("frequency")
    ax.set_xlabel("nn output")


# slow
from IPython.display import HTML
from matplotlib import pyplot as plt

plt.rcParams.update(
    {
        "axes.labelsize": 13,
        "axes.linewidth": 1.2,
        "xtick.labelsize": 13,
        "ytick.labelsize": 13,
        "figure.figsize": [13.0, 4.0],
        "font.size": 13,
        "xtick.major.size": 3,
        "ytick.major.size": 3,
        "legend.fontsize": 11,
    }
)

fig, axarr = plt.subplots(1, 3, dpi=120)

maxN = 50  # make me bigger for better results!
animate = True  # animations fail tests...

if animate:
    from celluloid import Camera

    camera = Camera(fig)

# Training
for i, (network, metrics, epoch_time) in enumerate(train_network(maxN)):
    print(f"epoch {i}:", f'CLs = {metrics["loss"][-1]}, took {epoch_time}s')
    if animate:
        plot(axarr, network, metrics, maxN=maxN)
        plt.tight_layout()
        camera.snap()
        # Periodically checkpoint the animation so partial runs still produce
        # a gif.
        if i % 10 == 0:
            camera.animate().save(
                "softmax_animation.gif", writer="imagemagick", fps=8
            )
            # HTML(camera.animate().to_html5_video())

if animate:
    camera.animate().save("softmax_animation.gif", writer="imagemagick", fps=8)