neos
We first need to install a special branch of pyhf:
%pip install git+http://github.com/scikit-hep/pyhf.git@make_difffable_model_ctor
Start with a couple of imports:
from jax.example_libraries import stax # neural network library for JAX
from jax.random import PRNGKey # random number generator
import jax.numpy as jnp # JAX's numpy
import neos # :)
neos experiments have been designed to run through a flexible Pipeline class, which will compose the necessary ingredients to train differentiable analyses end-to-end. We have other examples in the works, but for now, we have wrapped up our current experiments in a module called nn_observable:
from neos.experiments.nn_observable import (
nn_summary_stat, # create a summary statistic from a neural network
make_model, # use the summary statistic to make a HistFactory style model
generate_data, # generates gaussian blobs to feed into the nn
first_epoch, # special plotting callback for the first epoch
last_epoch, # special plotting callback for the last epoch
per_epoch, # generic plotting callback for each epoch
plot_setup, # initial setup for the plotting
)
Each of these functions is pretty lightweight (with the exception of the plotting) -- if you want to get experimental and write your own pipeline, the code for those functions is a good starting point!
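As a flavor of what's inside: the ingredient that makes the histogram yields differentiable is a kernel density estimate (KDE), where each event contributes a smooth bump to the bins instead of a hard count. Here's a minimal sketch of that idea in JAX -- an illustration of the technique, not neos' actual internals:
from jax.scipy.stats import norm

def kde_hist(events, bins, bandwidth):
    # integrate a Gaussian kernel centred on each event over every bin;
    # summing over events gives smooth, differentiable "bin counts"
    cdf_up = norm.cdf(bins[1:, None], loc=events, scale=bandwidth)
    cdf_dn = norm.cdf(bins[:-1, None], loc=events, scale=bandwidth)
    return (cdf_up - cdf_dn).sum(axis=1)

kde_hist(jnp.array([0.2, 0.5, 0.8]), bins=jnp.linspace(0, 1, 5), bandwidth=1e-1)
As the bandwidth shrinks, this approaches an ordinary (non-differentiable) histogram -- that's exactly the bandwidth knob you'll see in the pipeline below.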
Now we'll jump into training! First, we set up a neural network (for regression) and a random state:
rng_state = 0 # random state
# feel free to modify :)
init_random_params, nn = stax.serial(
stax.Dense(1024),
stax.Relu,
stax.Dense(1024),
stax.Relu,
stax.Dense(1),
stax.Sigmoid,
)
_, init_pars = init_random_params(PRNGKey(rng_state), (-1, 2))
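As a quick sanity check (using stax's standard apply-function signature), the network maps batches of 2-d points to a single number in [0, 1] per event -- that number is what we'll histogram as our summary statistic:
dummy = jnp.zeros((5, 2))  # five 2-d points, matching the (-1, 2) input shape above
nn(init_pars, dummy).shape  # (5, 1) -- the final sigmoid squashes each output into [0, 1]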
From there, we compose our pipeline with the relevant ingredients. I'll point out things you can play with immediately:
p = neos.Pipeline(
data=generate_data(rng=rng_state, num_points=10000), # total number of points
yield_kwargs=dict(
nn=nn, # the nn we defined above
bandwidth=1e-1, # bandwidth of the KDE (lower = more like a real histogram)
bins=jnp.linspace(0, 1, 5), # binning of the summary stat (over [0,1])
),
loss=lambda x: x["CLs"], # the quantity to minimize: the CLs value from the hypothesis test
num_epochs=10, # number of epochs
batch_size=500, # number of points per batch
plotname="demo_nn_observable.png", # save the final plot!
animate=True, # make cool animations!
animationname="demo_nn_observable.gif", # save them!
random_state=rng_state,
yields_from_pars=nn_summary_stat,
model_from_yields=make_model,
init_pars=init_pars,
first_epoch_callback=first_epoch,
last_epoch_callback=last_epoch,
per_epoch_callback=per_epoch,
plot_setup=plot_setup,
)
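To give a sense of what the last two ingredients do: yields_from_pars runs the data through the network and the KDE to produce per-bin yields, and model_from_yields turns those yields into a pyhf model. Here's a rough sketch of the latter -- a hypothetical stand-in for make_model, assuming a simple signal-plus-background setup (constructing the model from array-valued yields is what the special pyhf branch we installed enables):
import pyhf

pyhf.set_backend("jax")

def make_model_sketch(signal, background, background_uncertainty):
    # single-channel HistFactory-style model: signal on top of a background
    # with uncorrelated per-bin uncertainties
    return pyhf.simplemodels.uncorrelated_background(
        signal=signal,
        bkg=background,
        bkg_uncertainty=background_uncertainty,
    )
The CLs value that our loss picks out then comes from a hypothesis test on a model like this.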
Then we run! Each epoch takes around 15s on my local CPU, so expect something similar :)
You'll see some cool plots and animations, so it's worth it ;)
p.run()