import time
import jax
import jax.experimental.optimizers as optimizers
import jax.experimental.stax as stax
import jax.random
from jax.random import PRNGKey
import numpy as np
from functools import partial
import pyhf
# Select JAX as pyhf's tensor backend.
pyhf.set_backend("jax")
# Replace pyhf's default backend object with a JAX backend created with
# precision="64b".
# NOTE(review): presumably needed so internal pyhf code paths that use the
# default backend also run on 64-bit JAX arrays -- confirm against pyhf docs.
pyhf.default_backend = pyhf.tensor.jax_backend(precision="64b")
from neos import data, makers
from relaxed import infer
# PRNG key handed to the data generator further down in this file.
rng = PRNGKey(22)

# --- jax.experimental.stax ---
# Classifier built with stax: two 1024-unit ReLU hidden layers followed by a
# 2-unit softmax output. `stax.serial` returns an (init_fn, apply_fn) pair:
# `init_random_params` creates the parameters, `predict` applies the network.
init_random_params, predict = stax.serial(
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(2),
    stax.Softmax,
)
# Data generator: calling dgen() returns three point sets (unpacked in plot()
# as sig, bkg1, bkg2).
dgen = data.generate_blobs(rng, blobs=3)
# Histogram maker: hmaker(network_params) returns a (s, b, db) triple, as
# unpacked in plot().
hmaker = makers.hists_from_nn(dgen, predict, method="softmax", hpar_dict=None)
# NOTE(review): presumably wraps the histograms in a HEPData-like statistical
# model -- confirm against neos.makers.
nnm = makers.hepdata_like_from_hists(hmaker)
# Hypothesis-test callable: get_cls(params, test_mu) returns an object that is
# indexed with the key "CLs" in loss().
get_cls = infer.make_hypotest(nnm, solver_kwargs=dict(pdf_transform=True))
def loss(params, test_mu):
    """Return the ``"CLs"`` metric of the hypothesis test.

    ``get_cls`` returns several metrics keyed by name; only the one stored
    under ``"CLs"`` is used as the training objective.

    Args:
        params: network parameters forwarded to ``get_cls``.
        test_mu: value forwarded to ``get_cls`` as its second argument.

    Returns:
        The value stored under ``"CLs"`` in the result of ``get_cls``.
    """
    test_metrics = get_cls(params, test_mu)
    return test_metrics["CLs"]
# Create an initial set of network parameters from a fixed PRNG key; the
# first element returned by the init function is discarded.
# NOTE(review): (-1, 2) is the input shape given to stax -- presumably
# "any batch size, 2 features"; confirm against the stax docs.
_, network = init_random_params(jax.random.PRNGKey(2), (-1, 2))
# gradient wrt nn weights
# One-off evaluation of the loss and its gradient at test_mu=1.0. The result
# is not assigned to anything.
jax.value_and_grad(loss)(network, test_mu=1.0)
(DeviceArray(0.07398323, dtype=float64), [(DeviceArray([[-0.01394723, 0.00889332, -0.00755367, ..., -0.01692521, 0.00212111, -0.01970101], [-0.00611163, 0.0067339 , -0.00618019, ..., -0.01004537, 0.00136268, -0.01124781]], dtype=float32), DeviceArray([ 0.0052404 , 0.00297238, 0.00319649, ..., 0.00691849, -0.00113235, -0.00695833], dtype=float32)), (), (DeviceArray([[-5.5783303e-06, -4.1947313e-04, -6.2398164e-04, ..., -1.7087770e-03, 1.4778219e-05, -5.8028835e-04], [ 1.7306319e-04, -7.9908472e-07, 5.1939214e-04, ..., 4.1677014e-04, -1.1140389e-05, 9.1133501e-05], [ 1.3320147e-06, -1.7716239e-04, -1.8571694e-04, ..., -7.1186345e-04, 4.9324703e-06, -2.4198412e-04], ..., [-1.0420149e-05, -7.0664223e-04, -1.0708836e-03, ..., -2.8783218e-03, 2.5193000e-05, -9.7777299e-04], [ 3.6760284e-06, -3.1483595e-04, -3.7518013e-04, ..., -1.2577748e-03, 9.4429461e-06, -4.2865335e-04], [ 2.1981435e-04, 1.4891838e-06, 5.9667905e-04, ..., 6.1379891e-04, -1.2930058e-05, 8.6282926e-05]], dtype=float32), DeviceArray([ 2.1595680e-03, -3.3895427e-03, 1.5025048e-03, ..., -7.7358135e-03, -2.3924757e-05, -3.4947770e-03], dtype=float32)), (), (DeviceArray([[-0.0019205 , 0.0019205 ], [ 0.01192059, -0.01192059], [-0.00670956, 0.00670956], ..., [ 0.01208128, -0.01208128], [-0.00236412, 0.00236412], [ 0.00288121, -0.00288121]], dtype=float32), DeviceArray([-1.2484001e-05, 1.2484001e-05], dtype=float32)), ()])
# Adam optimizer created with argument 1e-3. The returned triple is used in
# train_network() as: opt_init(params) -> state, opt_update(i, grad, state)
# -> new state, opt_params(state) -> params.
opt_init, opt_update, opt_params = optimizers.adam(1e-3)
def train_network(N):
    """Train a freshly initialised network for ``N`` steps, yielding each step.

    The network parameters are re-created from a fixed PRNG key on every
    call, wrapped in the module-level optimizer (``opt_init`` /
    ``opt_update`` / ``opt_params``) and optimised against ``loss`` at a
    fixed ``mu`` of 1.0.

    Args:
        N: number of optimisation steps to run.

    Yields:
        A ``(network, metrics, epoch_time)`` tuple per step, where

        * ``network`` is the parameter set the step's loss was evaluated at
          (i.e. the parameters *before* that step's update is applied),
        * ``metrics`` is ``{"loss": losses}``; ``losses`` is the same list
          object on every step and grows by one entry per step,
        * ``epoch_time`` is the wall-clock duration of the step in seconds.
    """
    _, network = init_random_params(jax.random.PRNGKey(1), (-1, 2))
    state = opt_init(network)
    losses = []

    # parameter update function
    # NOTE: jit is left disabled here, exactly as in the original.
    # @jax.jit
    def update_and_value(i, opt_state, mu):
        net = opt_params(opt_state)
        value, grad = jax.value_and_grad(loss)(net, mu)
        # Update the state that was passed in rather than the closed-over
        # `state` variable, so the function depends only on its arguments
        # (required for it to stay correct if it is ever jit-compiled).
        return opt_update(i, grad, opt_state), value, net

    for i in range(N):
        start_time = time.time()
        state, value, network = update_and_value(i, state, 1.0)
        epoch_time = time.time() - start_time
        losses.append(value)
        metrics = {"loss": losses}
        yield network, metrics, epoch_time
def plot(axarr, network, metrics, maxN):
    """Draw one three-panel frame for the current training state.

    Panel 0 shows contours of the first network output over the square
    [-5, 5] x [-5, 5] together with the generated points, panel 1 shows the
    loss history, and panel 2 shows the histograms produced by ``hmaker``.

    Args:
        axarr: indexable holding at least three matplotlib axes.
        network: network parameters passed to ``predict`` and ``hmaker``.
        metrics: mapping with a ``"loss"`` entry holding the loss history.
        maxN: upper x-axis limit of the loss panel.
    """
    # --- panel 0: network output over a 101 x 101 grid, plus the data ---
    ax = axarr[0]
    g = np.mgrid[-5:5:101j, -5:5:101j]
    levels = np.linspace(0, 1, 21)
    # Run the network over the grid once and reuse the result for both the
    # filled and the line contours (the original evaluated it twice).
    grid_points = np.moveaxis(g, 0, -1)
    first_output = predict(network, grid_points).reshape(101, 101, 2)[:, :, 0]
    ax.contourf(
        g[0],
        g[1],
        first_output,
        levels=levels,
        cmap="BrBG",
    )
    ax.contour(
        g[0],
        g[1],
        first_output,
        colors="w",
        levels=levels,
    )
    sig, bkg1, bkg2 = dgen()
    ax.scatter(sig[:, 0], sig[:, 1], alpha=0.25, c="C9", label="sig")
    ax.scatter(bkg1[:, 0], bkg1[:, 1], alpha=0.17, c="C1", label="bkg1")
    ax.scatter(bkg2[:, 0], bkg2[:, 1], alpha=0.17, c="C1", label="bkg2")
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    # --- panel 1: loss history, with a dashed reference line at 0.05 ---
    ax = axarr[1]
    ax.axhline(0.05, c="slategray", linestyle="--")
    ax.plot(metrics["loss"], c="steelblue", linewidth=2.0)
    ax.set_ylim(0, 0.6)
    ax.set_xlim(0, maxN)
    ax.set_xlabel("epoch")
    ax.set_ylabel(r"$cl_s$")

    # --- panel 2: two-bin histograms from the network output ---
    ax = axarr[2]
    s, b, db = hmaker(network)
    ax.bar([0, 1], b, color="C1", label="bkg")
    ax.bar([0, 1], s, bottom=b, color="C9", label="sig")
    # Error band of height db centred on the top of the background bar.
    ax.bar([0, 1], db, bottom=b - db / 2.0, alpha=0.4, color="black", label="bkg error")
    ax.set_ylim(0, 100)
    ax.set_ylabel("frequency")
    ax.set_xlabel("nn output")
# slow
import numpy as np
from IPython.display import HTML
from matplotlib import pyplot as plt
# Global matplotlib styling for the three-panel figure.
plt.rcParams.update(
    {
        "axes.labelsize": 13,
        "axes.linewidth": 1.2,
        "xtick.labelsize": 13,
        "ytick.labelsize": 13,
        "figure.figsize": [13.0, 4.0],
        "font.size": 13,
        "xtick.major.size": 3,
        "ytick.major.size": 3,
        "legend.fontsize": 11,
    }
)

fig, axarr = plt.subplots(1, 3, dpi=120)

maxN = 50  # make me bigger for better results!

animate = True  # animations fail tests...

if animate:
    # celluloid is only imported when an animation is actually recorded.
    from celluloid import Camera

    camera = Camera(fig)

# Training: run the optimisation, report each step, and (optionally) record
# one animation frame per step.
for i, (network, metrics, epoch_time) in enumerate(train_network(maxN)):
    print(f"epoch {i}:", f'CLs = {metrics["loss"][-1]}, took {epoch_time}s')
    if not animate:
        continue
    plot(axarr, network, metrics, maxN=maxN)
    plt.tight_layout()
    camera.snap()
    # Write the animation recorded so far every tenth step.
    if i % 10 == 0:
        camera.animate().save("softmax_animation.gif", writer="imagemagick", fps=8)
        # HTML(camera.animate().to_html5_video())

# Write the complete animation once training has finished.
if animate:
    camera.animate().save("softmax_animation.gif", writer="imagemagick", fps=8)
epoch 0: CLs = 0.0728174959879968, took 1.5502469539642334s epoch 1: CLs = 0.6033747613050151, took 1.305609941482544s epoch 2: CLs = 0.4772362510826502, took 1.2987580299377441s epoch 3: CLs = 0.303203522165707, took 1.3123400211334229s epoch 4: CLs = 0.17807919527972027, took 1.3043370246887207s epoch 5: CLs = 0.09792412166611708, took 1.3042569160461426s epoch 6: CLs = 0.05045661928240319, took 1.312431812286377s epoch 7: CLs = 0.024775323853546416, took 1.312417984008789s epoch 8: CLs = 0.012006326372574483, took 1.3048739433288574s epoch 9: CLs = 0.005995956652717016, took 1.3154382705688477s epoch 10: CLs = 0.003201696644957508, took 1.304189920425415s epoch 11: CLs = 0.0018730477806756518, took 1.32132887840271s epoch 12: CLs = 0.0012097872459944092, took 1.3092951774597168s epoch 13: CLs = 0.000855258798814873, took 1.312361240386963s epoch 14: CLs = 0.0006508069683333062, took 1.3200039863586426s epoch 15: CLs = 0.0005240430154678233, took 1.3494877815246582s epoch 16: CLs = 0.0004404998669957916, took 1.377411127090454s epoch 17: CLs = 0.00038245353794441606, took 1.3657469749450684s epoch 18: CLs = 0.00034024091764739417, took 1.3846428394317627s epoch 19: CLs = 0.00030837757126023213, took 1.4826250076293945s epoch 20: CLs = 0.0002835917177304026, took 1.6086061000823975s epoch 21: CLs = 0.00026380282667526345, took 1.5852391719818115s epoch 22: CLs = 0.0002476531490986922, took 1.5814290046691895s epoch 23: CLs = 0.00023425802988708, took 1.6054887771606445s epoch 24: CLs = 0.0002229236297943693, took 1.5866038799285889s epoch 25: CLs = 0.00021321015619446548, took 1.5590381622314453s epoch 26: CLs = 0.00020480094785124692, took 1.5385549068450928s epoch 27: CLs = 0.00019743821334361478, took 1.5394530296325684s epoch 28: CLs = 0.0001909311924312984, took 1.4920289516448975s epoch 29: CLs = 0.00018512781060087136, took 1.459669828414917s epoch 30: CLs = 0.00017990690840319346, took 1.4979851245880127s epoch 31: CLs = 0.00017517614911688462, took 
1.5110201835632324s epoch 32: CLs = 0.00017085840239827732, took 1.5598220825195312s epoch 33: CLs = 0.00016689837871886049, took 1.4957211017608643s epoch 34: CLs = 0.0001632393686050726, took 1.5260207653045654s epoch 35: CLs = 0.0001598474709885167, took 1.472264051437378s epoch 36: CLs = 0.00015668811114721848, took 1.475229024887085s epoch 37: CLs = 0.0001537304234133785, took 1.4605090618133545s epoch 38: CLs = 0.00015094518410441182, took 1.44688081741333s epoch 39: CLs = 0.0001483019560855059, took 1.4090900421142578s epoch 40: CLs = 0.00014579038403894629, took 1.4354889392852783s epoch 41: CLs = 0.00014338054588325377, took 1.4165687561035156s epoch 42: CLs = 0.00014105497004068823, took 1.456571102142334s epoch 43: CLs = 0.00013886767077475604, took 1.476653814315796s epoch 44: CLs = 0.00013673005949699224, took 1.4481470584869385s epoch 45: CLs = 0.00013470501189205564, took 1.4955050945281982s epoch 46: CLs = 0.00013274539790075757, took 1.4935669898986816s epoch 47: CLs = 0.00013083557800119827, took 1.554062843322754s epoch 48: CLs = 0.0001289809847608847, took 1.5153779983520508s epoch 49: CLs = 0.00012718516979659533, took 1.516117811203003s