"""Train a small dense network against a CLs loss and animate the training.

The script builds a stax network, turns its outputs into KDE histograms via
``neos.makers``, builds a hypothesis test with ``relaxed.infer`` and minimises
the resulting ``"CLs"`` value with Adam. When ``animate`` is true, each epoch
is drawn into a three-panel figure and the frames are saved as a GIF.
"""
import time

import jax
import jax.experimental.optimizers as optimizers
import jax.experimental.stax as stax
import jax.random
from jax.random import PRNGKey
import numpy as np
from functools import partial

import pyhf

# The backend must be configured before the neos/relaxed helpers are used.
pyhf.set_backend('jax')
pyhf.default_backend = pyhf.tensor.jax_backend(precision='64b')

from neos import data, makers
from relaxed import infer

rng = PRNGKey(22)

# regression net
init_random_params, predict = stax.serial(
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid,
)

dgen = data.generate_blobs(rng, blobs=4)

# Specify our hyperparameters ahead of time for the kde histograms
bins = np.linspace(0, 1, 4)
bandwidth = 0.27
reflect_infinite_bins = True

hmaker = makers.hists_from_nn(
    dgen,
    predict,
    hpar_dict=dict(bins=bins, bandwidth=bandwidth),
    method='kde',
    reflect_infinities=reflect_infinite_bins,
)
nnm = makers.histosys_model_from_hists(hmaker)
get_cls = infer.make_hypotest(nnm, solver_kwargs=dict(pdf_transform=True))


# loss returns a list of metrics -- let's just index into one (CLs)
def loss(params, test_mu):
    """Return the ``"CLs"`` entry of ``get_cls(params, test_mu)``.

    Args:
        params: Network parameters, as produced by ``init_random_params``.
        test_mu: Value forwarded to ``get_cls`` as its second argument.

    Returns:
        The ``"CLs"`` metric, used as the training objective.
    """
    return get_cls(params, test_mu)["CLs"]


_, network = init_random_params(jax.random.PRNGKey(2), (-1, 2))

# gradient wrt nn weights
jax.value_and_grad(loss)(network, test_mu=1.0)

opt_init, opt_update, opt_params = optimizers.adam(1e-3)


def train_network(N):
    """Run ``N`` Adam steps on ``loss`` and yield progress after each step.

    Args:
        N: Number of optimisation steps (epochs) to run.

    Yields:
        Tuples ``(network, metrics, epoch_time)`` where ``network`` holds the
        parameters the loss was evaluated at in this step (i.e. before the
        step's update is applied), ``metrics`` is ``{"loss": losses}`` with
        ``losses`` being the same growing list on every iteration, and
        ``epoch_time`` is the wall-clock duration of the step in seconds.
    """
    _, network = init_random_params(jax.random.PRNGKey(1), (-1, 2))
    state = opt_init(network)
    losses = []

    # parameter update function
    # @jax.jit
    def update_and_value(i, opt_state, mu):
        net = opt_params(opt_state)
        value, grad = jax.value_and_grad(loss)(net, mu)
        # Update from the state passed in rather than the enclosing `state`.
        # The two coincide while this runs eagerly, but a closed-over value
        # would be captured at trace time if the jit decorator were enabled.
        return opt_update(i, grad, opt_state), value, net

    for i in range(N):
        start_time = time.time()
        state, value, network = update_and_value(i, state, 1.0)
        epoch_time = time.time() - start_time
        losses.append(value)
        metrics = {"loss": losses}
        yield network, metrics, epoch_time


def plot(axarr, network, metrics, maxN):
    """Draw one frame of the training visualisation onto three axes.

    Args:
        axarr: Indexable collection of at least three matplotlib axes.
        network: Network parameters passed to ``predict`` and ``hmaker``.
        metrics: Mapping with a ``"loss"`` entry holding the loss history.
        maxN: Upper x-limit of the loss panel.

    Panels:
        ``axarr[0]``: contours of the network output on a 101x101 grid over
        [-5, 5] x [-5, 5], overlaid with the samples returned by ``dgen()``.
        ``axarr[1]``: loss history on a log y-scale.
        ``axarr[2]``: stacked bar histograms returned by ``hmaker(network)``
        with the up/down variations drawn relative to the nominal background.
    """
    ax = axarr[0]
    g = np.mgrid[-5:5:101j, -5:5:101j]
    levels = bins
    # Evaluate the network on the grid once and reuse it for both contour calls.
    grid_pred = predict(network, np.moveaxis(g, 0, -1)).reshape(101, 101, 1)[:, :, 0]
    ax.contourf(
        g[0],
        g[1],
        grid_pred,
        levels=levels,
        cmap="binary",
    )
    ax.contour(
        g[0],
        g[1],
        grid_pred,
        colors="w",
        levels=levels,
    )
    sig, bkg_nom, bkg_up, bkg_down = dgen()

    ax.scatter(sig[:, 0], sig[:, 1], alpha=0.3, c="C9")
    ax.scatter(bkg_up[:, 0], bkg_up[:, 1], alpha=0.1, c="C1", marker=6)
    ax.scatter(bkg_down[:, 0], bkg_down[:, 1], alpha=0.1, c="C1", marker=7)
    ax.scatter(bkg_nom[:, 0], bkg_nom[:, 1], alpha=0.3, c="C1")

    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    ax = axarr[1]
    # ax.axhline(0.05, c="slategray", linestyle="--")
    ax.plot(metrics["loss"], c="steelblue", linewidth=2.0)
    ax.set_yscale("log")
    ax.set_ylim(1e-4, 0.06)
    ax.set_xlim(0, maxN)
    ax.set_xlabel("epoch")
    ax.set_ylabel(r"$CL_s$")

    ax = axarr[2]
    s, b, bup, bdown = hmaker(network)

    bin_width = 1 / (len(bins) - 1)
    centers = bins[:-1] + np.diff(bins) / 2.0
    ax.bar(centers, b, color="C1", width=bin_width)
    ax.bar(centers, s, bottom=b, color="C9", width=bin_width)

    # Order each (up, down) pair as [larger, smaller], then clamp the pair so
    # that it always brackets the nominal background yield.
    bunc = np.asarray([[x, y] if x > y else [y, x] for x, y in zip(bup, bdown)])
    plot_unc = []
    for unc, be in zip(bunc, b):
        if all(unc > be):
            plot_unc.append([max(unc), be])
        elif all(unc < be):
            plot_unc.append([be, min(unc)])
        else:
            plot_unc.append(unc)

    plot_unc = np.asarray(plot_unc)
    # NOTE(review): b_up / b_down are computed but never used; the bars below
    # are drawn from the raw bup / bdown. Confirm which pair was intended.
    b_up, b_down = plot_unc[:, 0], plot_unc[:, 1]
    ax.bar(centers, bup - b, bottom=b, alpha=0.4, color="red", width=bin_width, hatch="+")
    ax.bar(
        centers, b - bdown, bottom=bdown, alpha=0.4, color="green", width=bin_width, hatch="-"
    )
    ax.set_ylim(0, 120)
    ax.set_ylabel("frequency")
    ax.set_xlabel("nn output")


# slow
import numpy as np
from IPython.display import HTML
from matplotlib import pyplot as plt

plt.rcParams.update(
    {
        "axes.labelsize": 13,
        "axes.linewidth": 1.2,
        "xtick.labelsize": 13,
        "ytick.labelsize": 13,
        "figure.figsize": [13.0, 4.0],
        "font.size": 13,
        "xtick.major.size": 3,
        "ytick.major.size": 3,
        "legend.fontsize": 11,
    }
)

fig, axarr = plt.subplots(1, 3, dpi=120)

maxN = 500  # make me bigger for better results!

animate = True  # animations fail tests...

if animate:
    from celluloid import Camera

    camera = Camera(fig)

# Training
for i, (network, metrics, epoch_time) in enumerate(train_network(maxN)):
    # print(f"epoch {i}:", f'CLs = {metrics["loss"][-1]}, took {epoch_time}s')
    if animate:
        plot(axarr, network, metrics, maxN=maxN)
        plt.tight_layout()
        camera.snap()
        # if i % 10 == 0:
        #     camera.animate().save("animation.gif", writer="imagemagick", fps=8)
        #     # HTML(camera.animate().to_html5_video())
        # break

if animate:
    camera.animate().save("animationinfesoft.gif", writer="imagemagick", fps=15)