so many bins :)
import jax
import neos.makers as makers
import neos.cls as cls
import numpy as np
import jax.experimental.stax as stax
import jax.experimental.optimizers as optimizers
import jax.random
import time
import pyhf
# Select JAX as pyhf's tensor backend.
pyhf.set_backend(pyhf.tensor.jax_backend())

# regression net: Dense(1024) -> ReLU -> Dense(1024) -> ReLU -> Dense(1) -> Sigmoid
_hidden_width = 1024
_layers = [
    stax.Dense(_hidden_width),
    stax.Relu,
    stax.Dense(_hidden_width),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid,
]
# `init_random_params` builds initial parameters; `predict` applies the
# network to inputs given a set of parameters.
init_random_params, predict = stax.serial(*_layers)
# Binning of the network output: 6 equally spaced edges on [0, 1] -> 5 bins.
bins = np.linspace(0, 1, 6)
_half_widths = np.diff(bins) / 2.0
centers = bins[:-1] + _half_widths

# Bandwidth handed to the histogram maker: 0.4 divided by the number of bins.
_n_bins = len(bins) - 1
bandwidth = 0.4 * 1 / _n_bins

# Chain the neos makers: network -> binned yields -> model -> loss.
# `loss` is the callable that is differentiated and minimised below,
# invoked as loss(network, mu).
hmaker = makers.kde_bins_from_nn_three_blobs(
    predict, bins=bins, bandwidth=bandwidth
)
nnm = makers.nn_hepdata_like(hmaker)
loss = cls.cls_maker(nnm, solver_kwargs={"pdf_transform": True})
bandwidth
0.08
# One-off check: jit-compile the loss together with its gradient and evaluate
# both on a freshly initialised network, with 1.0 as the second argument.
_loss_and_grad = jax.value_and_grad(loss)
gradd = jax.jit(_loss_and_grad)
_init_key = jax.random.PRNGKey(13)
_, network = init_random_params(_init_key, (-1, 2))
gradd(network, 1.0)
(DeviceArray(0.05880111, dtype=float64), [(DeviceArray([[-0.00179044, -0.00056627, 0.00202215, ..., -0.00396901, -0.00175673, -0.00723417], [-0.00017432, -0.00024473, -0.00143037, ..., -0.00237224, -0.00317624, -0.00390435]], dtype=float32), DeviceArray([ 0.00124554, -0.00019967, -0.00139115, ..., -0.00029852, -0.00180919, -0.00188678], dtype=float32)), (), (DeviceArray([[-2.70764849e-05, -6.28665646e-07, -1.20917028e-08, ..., 4.85653118e-06, -8.11458813e-05, 3.06508809e-05], [ 2.83182326e-05, 3.31130068e-05, 4.30183434e-07, ..., 0.00000000e+00, 1.55358535e-07, -2.02215233e-07], [-1.45628455e-05, -2.74914839e-07, -5.74184167e-09, ..., 2.39836436e-06, -6.88106229e-05, 3.14157187e-05], ..., [-3.00191073e-06, 7.77033274e-05, 9.19293257e-07, ..., 1.99842589e-06, -2.74845777e-04, 1.59332907e-04], [ 1.02543192e-04, 1.23361606e-04, 1.59768263e-06, ..., 3.53127803e-08, -5.06841570e-06, 4.16561761e-06], [ 5.91500029e-05, 8.81387532e-05, 1.13218164e-06, ..., 0.00000000e+00, -9.30895840e-06, 6.57422561e-06]], dtype=float32), DeviceArray([ 1.0152139e-03, 1.3619895e-03, 1.7323653e-05, ..., 4.6312110e-05, -1.7732537e-03, 8.8287907e-04], dtype=float32)), (), (DeviceArray([[-0.00040312], [-0.00372409], [-0.00373478], ..., [ 0.000288 ], [ 0.0022565 ], [ 0.00162871]], dtype=float32), DeviceArray([-0.0001226], dtype=float32)), ()])
#jit_loss = jax.jit(loss)
# Adam optimizer triple, constructed with 3e-3 as its argument:
#   opt_init(params)              -> initial optimizer state
#   opt_update(i, grad, state)    -> next optimizer state
#   opt_params(state)             -> network parameters held by that state
opt_init, opt_update, opt_params = optimizers.adam(3e-3)
# NOTE: the `@jax.jit` decorator was commented out in the original and is
# still left off here.
def update_and_value(i, opt_state, mu):
    """Take one optimizer step and report the loss evaluated before the step.

    Args:
        i: Step index forwarded to ``opt_update``.
        opt_state: Optimizer state to read the network from and to update.
        mu: Second argument forwarded to ``loss``.

    Returns:
        Tuple ``(new_opt_state, value, net)`` where ``net`` are the
        parameters held by the incoming ``opt_state`` and ``value`` is
        ``loss(net, mu)``.
    """
    net = opt_params(opt_state)
    value, grad = jax.value_and_grad(loss)(net, mu)
    # Update the state that was passed in. The previous code referenced a
    # name `state` that does not exist at module level (NameError on call).
    return opt_update(i, grad, opt_state), value, net
def train_network(N):
    """Train a freshly initialised network for ``N`` steps, yielding progress.

    This is a generator: nothing runs until it is iterated, and one
    optimizer step is performed per item consumed.

    Args:
        N: Number of optimizer steps to run.

    Yields:
        Tuple ``(network, metrics, epoch_time)``:

        * ``network`` -- the parameters the loss was evaluated at in this
          step, i.e. the parameters *before* this step's update.
        * ``metrics`` -- dict whose ``"loss"`` entry is the list of all loss
          values so far; the same list object is extended on every step.
        * ``epoch_time`` -- seconds elapsed (``time.time()`` difference)
          during the step.
    """
    _, network = init_random_params(jax.random.PRNGKey(1), (-1, 2))
    state = opt_init(network)
    losses = []

    # parameter update function (the `@jax.jit` decorator is left disabled)
    def update_and_value(i, opt_state, mu):
        net = opt_params(opt_state)
        value, grad = jax.value_and_grad(loss)(net, mu)
        # Use the state passed as argument rather than the enclosing
        # `state` variable, so the helper does not depend on the caller
        # always handing in that exact object.
        return opt_update(i, grad, opt_state), value, net

    for i in range(N):
        start_time = time.time()
        state, value, network = update_and_value(i, state, 1.0)
        epoch_time = time.time() - start_time
        losses.append(value)
        metrics = {"loss": losses}
        yield network, metrics, epoch_time
def plot(axarr, network, metrics, hm, maxN):
    """Draw the three diagnostic panels for the current training state.

    Panel 0 shows the network output over the square [-5, 5] x [-5, 5] with
    the sample points overlaid, panel 1 the loss history, and panel 2 the
    stacked histogram returned by ``hm(network)``.

    Args:
        axarr: Indexable collection of at least three matplotlib axes.
        network: Network parameters passed to ``predict`` and ``hm``.
        metrics: Dict whose ``"loss"`` entry is the sequence of loss values.
        hm: Callable such that ``hm(network)`` returns ``(s, b, db)``; it
            must also expose 2-column arrays ``sig``, ``bkg1`` and ``bkg2``.
        maxN: Upper x-limit of the loss panel.
    """
    # --- panel 0: network output contours + sample scatter ---------------
    ax = axarr[0]
    g = np.mgrid[-5:5:101j, -5:5:101j]
    levels = bins
    # Evaluate the network on the grid once and share the result between the
    # filled and the line contours (previously computed twice per call).
    grid_points = np.moveaxis(g, 0, -1)
    nn_output = predict(network, grid_points).reshape(101, 101, 1)[:, :, 0]
    ax.contourf(g[0], g[1], nn_output, levels=levels, cmap="jet")
    ax.contour(g[0], g[1], nn_output, colors="w", levels=levels)
    ax.scatter(hm.sig[:, 0], hm.sig[:, 1], alpha=0.25, c="C9", label="sig")
    ax.scatter(hm.bkg1[:, 0], hm.bkg1[:, 1], alpha=0.17, c="C1", label="bkg1")
    ax.scatter(hm.bkg2[:, 0], hm.bkg2[:, 1], alpha=0.17, c="C1", label="bkg2")
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    # --- panel 1: loss history with a dashed reference line at 0.05 ------
    ax = axarr[1]
    ax.axhline(0.05, c="slategray", linestyle="--")
    ax.plot(metrics["loss"], c="steelblue", linewidth=2.0)
    ax.set_ylim(0, 0.6)
    ax.set_xlim(0, maxN)
    ax.set_xlabel("epoch")
    ax.set_ylabel(r"$cl_s$")

    # --- panel 2: stacked histogram of the network output ----------------
    ax = axarr[2]
    s, b, db = hm(network)
    bin_width = 1 / (len(bins) - 1)
    ax.bar(centers, b, color="C1", label="bkg", width=bin_width)
    ax.bar(centers, s, bottom=b, color="C9", label="sig", width=bin_width)
    # Bar of height db centred on b.
    ax.bar(
        centers,
        db,
        bottom=b - db / 2.0,
        alpha=0.4,
        color="black",
        label="bkg error",
        width=bin_width,
    )
    ax.set_ylim(0, 100)
    ax.set_ylabel("frequency")
    ax.set_xlabel("nn output")
!python -m pip install celluloid
#slow
import numpy as np
from matplotlib import pyplot as plt
from IPython.display import HTML
# Figure-wide matplotlib defaults used by the plots below.
_plot_style = {
    "axes.labelsize": 13,
    "axes.linewidth": 1.2,
    "xtick.labelsize": 13,
    "ytick.labelsize": 13,
    "figure.figsize": [13.0, 4.0],
    "font.size": 13,
    "xtick.major.size": 3,
    "ytick.major.size": 3,
    "legend.fontsize": 11,
}
plt.rcParams.update(_plot_style)
# One row of three panels, filled in by `plot`.
fig, axarr = plt.subplots(1, 3, dpi=120)

maxN = 50  # make me bigger for better results!

animate = False  # animations fail tests...
if animate:
    # Imported lazily so celluloid is only needed when animating.
    from celluloid import Camera

    camera = Camera(fig)

# Training
for i, (network, metrics, epoch_time) in enumerate(train_network(maxN)):
    print(f"epoch {i}:", f'CLs = {metrics["loss"][-1]}, took {epoch_time}s')

    if not animate:
        continue

    # Animation only: draw the current state and record it as a frame.
    plot(axarr, network, metrics, nnm.hm, maxN=maxN)
    plt.tight_layout()
    camera.snap()

    if i % 10 != 0:
        continue

    # Every 10th step, write out the frames collected so far.
    camera.animate().save("animation.gif", writer="imagemagick", fps=8)
    #HTML(camera.animate().to_html5_video())

# Final write once training has finished.
if animate:
    camera.animate().save("animation.gif", writer="imagemagick", fps=8)
epoch 0: CLs = 0.061524204700417195, took 7.181300163269043s epoch 1: CLs = 0.038272014302923596, took 1.3694071769714355s epoch 2: CLs = 0.0038476321388671852, took 1.3954288959503174s epoch 3: CLs = 0.04351751517576652, took 1.365785837173462s epoch 4: CLs = 0.006314677268905777, took 1.3489091396331787s epoch 5: CLs = 0.0019261697248236231, took 1.3820810317993164s epoch 6: CLs = 0.0028004005358872597, took 1.3562629222869873s epoch 7: CLs = 0.002297108891974986, took 1.3525168895721436s epoch 8: CLs = 0.010675734361591749, took 1.348271131515503s epoch 9: CLs = 0.01199155502521454, took 1.3472709655761719s epoch 10: CLs = 0.00762593726609162, took 1.3535170555114746s epoch 11: CLs = 0.0032174685672372583, took 1.370344877243042s epoch 12: CLs = 0.001166807205252418, took 1.3701229095458984s epoch 13: CLs = 0.0004404550097800719, took 1.3815929889678955s epoch 14: CLs = 0.0004199528500352656, took 1.3516571521759033s epoch 15: CLs = 0.00043441462951632204, took 1.3560807704925537s epoch 16: CLs = 0.00042500279217305703, took 1.341277837753296s epoch 17: CLs = 0.00042511235116404755, took 1.361943006515503s epoch 18: CLs = 0.00043839415488466926, took 1.3481249809265137s epoch 19: CLs = 0.0004592491846160396, took 1.3479270935058594s epoch 20: CLs = 0.0004820709901054432, took 1.3617238998413086s epoch 21: CLs = 0.0005027081616038043, took 1.3702890872955322s epoch 22: CLs = 0.0005197880109817365, took 1.3762280941009521s epoch 23: CLs = 0.0005333018659892108, took 1.3563430309295654s epoch 24: CLs = 0.0005436285357709458, took 1.349862813949585s epoch 25: CLs = 0.000551007436950357, took 1.345383882522583s epoch 26: CLs = 0.000554992768869722, took 1.3663418292999268s epoch 27: CLs = 0.0005553554655504112, took 1.365978717803955s epoch 28: CLs = 0.0005524540322388027, took 1.3484511375427246s epoch 29: CLs = 0.0005461374322499601, took 1.3595550060272217s epoch 30: CLs = 0.000536697891139859, took 1.3438730239868164s epoch 31: CLs = 0.0005243710673876745, took 
1.3666539192199707s epoch 32: CLs = 0.0005091909195142907, took 1.3533861637115479s epoch 33: CLs = 0.0004911840439387749, took 1.3540828227996826s epoch 34: CLs = 0.0004710972245902667, took 1.3600499629974365s epoch 35: CLs = 0.0004502861618211895, took 1.3401401042938232s epoch 36: CLs = 0.0004293846931564538, took 1.3844318389892578s epoch 37: CLs = 0.0004090827937917041, took 1.364354133605957s epoch 38: CLs = 0.000389510146778127, took 1.36216402053833s epoch 39: CLs = 0.00037028015231888034, took 1.349194049835205s epoch 40: CLs = 0.0003507029564255859, took 1.3452768325805664s epoch 41: CLs = 0.00032908757139238354, took 1.3685681819915771s epoch 42: CLs = 0.00030342042498654465, took 1.3600702285766602s epoch 43: CLs = 0.0002772329277664909, took 1.3491170406341553s epoch 44: CLs = 0.0002695513717874132, took 1.7477397918701172s epoch 45: CLs = 0.0002922674396699243, took 1.377424955368042s epoch 46: CLs = 0.00030858496496199983, took 1.3887889385223389s epoch 47: CLs = 0.00030070589406672177, took 1.358922004699707s epoch 48: CLs = 0.000279831606425196, took 1.3844869136810303s epoch 49: CLs = 0.0002677123425813832, took 1.4390051364898682s