neos
We first need to install a special branch of pyhf:
%pip install git+https://github.com/scikit-hep/pyhf.git@make_difffable_model_ctor
After the install finishes, you may need to restart the kernel to use the updated packages.
Start with a couple of imports:
from jax.example_libraries import stax # neural network library for JAX
from jax.random import PRNGKey # random number generator
import jax.numpy as jnp # JAX's numpy
import neos # :)
neos experiments have been designed to run through a flexible Pipeline class, which composes the necessary ingredients to train differentiable analyses end-to-end.
We have other examples in the works, but for now, we have wrapped up our current experiments in a module called nn_observable:
from neos.experiments.nn_observable import (
    nn_summary_stat,  # create a summary statistic from a neural network
    make_model,  # use the summary statistic to make a HistFactory-style model
    generate_data,  # generate Gaussian blobs to feed into the nn
    first_epoch,  # special plotting callback for the first epoch
    last_epoch,  # special plotting callback for the last epoch
    per_epoch,  # generic plotting callback for each epoch
    plot_setup,  # initial setup for the plotting
)
Each of these functions is pretty lightweight (with the exception of the plotting ones), so if you want to get experimental and write your own pipeline, their source code is a good starting point!
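Because the Pipeline composes these ingredients, it can help to picture the result as a single function from network parameters to a scalar loss. The following is a hypothetical sketch in plain Python, not neos code: compose_loss, toy_yields, toy_model and toy_metrics are illustrative names, and the s/sqrt(b) significance is only a stand-in for the real metrics such as CLs. In neos the corresponding roles are presumably played by nn_summary_stat, make_model and the loss passed to the pipeline, all written in JAX so that the whole chain can be differentiated.

```python
import math
from typing import Callable, Dict

Metrics = Dict[str, float]


def compose_loss(
    yields_from_pars: Callable,
    model_from_yields: Callable,
    metrics_from_model: Callable,
    loss: Callable[[Metrics], float],
) -> Callable[[float], float]:
    """Chain the ingredients into one function: pars -> scalar loss (hypothetical helper)."""

    def loss_fn(pars: float) -> float:
        yields = yields_from_pars(pars)  # role of nn_summary_stat in the demo
        model = model_from_yields(yields)  # role of make_model in the demo
        return loss(metrics_from_model(model))  # role of lambda x: x["CLs"]

    return loss_fn


# Toy, hypothetical stand-ins: one parameter in [0, 1] trades background for signal.
def toy_yields(pars: float) -> Dict[str, float]:
    return {"s": 10.0 * pars, "b": 100.0 * (1.0 - pars) + 1.0}


def toy_model(yields: Dict[str, float]) -> Dict[str, float]:
    return yields  # a real model would be a HistFactory-style pyhf model


def toy_metrics(model: Dict[str, float]) -> Metrics:
    # s / sqrt(b) is a crude placeholder for the real statistical metrics.
    return {"significance": model["s"] / math.sqrt(model["b"])}


loss_fn = compose_loss(toy_yields, toy_model, toy_metrics, loss=lambda m: -m["significance"])
print(loss_fn(0.5), loss_fn(0.9))  # the second is lower: more signal, less background
```

Since each stage only sees the previous stage's output, swapping in a different model builder or loss touches a single argument, which is the kind of flexibility the Pipeline appears designed for.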
Now we'll jump into training! First, we set up a neural network (for regression) and a random state:
rng_state = 0 # random state
# feel free to modify :)
init_random_params, nn = stax.serial(
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid,
)
_, init_pars = init_random_params(PRNGKey(rng_state), (-1, 2))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
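For intuition about what this network computes: stax.serial returns an (init_fun, apply_fun) pair, and init_fun(rng, input_shape) returns (output_shape, params), so the value discarded with _ above is the output shape. The input shape (-1, 2) means a batch of any size of 2D points. The plain-Python sketch below mirrors the same layer sequence with a much smaller hidden width and an arbitrary weight initialisation; it is an illustration, not how stax implements these layers. The final sigmoid is what keeps the summary statistic inside (0, 1), the range that the binning over [0, 1] relies on.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]
Layer = Tuple[Matrix, List[float]]


def dense_init(rng: random.Random, n_in: int, n_out: int) -> Layer:
    # Arbitrary choice for this sketch: N(0, 0.5^2) weights, zero biases.
    W = [[rng.gauss(0.0, 0.5) for _ in range(n_out)] for _ in range(n_in)]
    b = [0.0] * n_out
    return W, b


def dense(x: List[float], layer: Layer) -> List[float]:
    W, b = layer
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]


def relu(v: List[float]) -> List[float]:
    return [max(0.0, t) for t in v]


def sigmoid(v: List[float]) -> List[float]:
    return [1.0 / (1.0 + math.exp(-t)) for t in v]


def forward(params: List[Layer], x: List[float]) -> List[float]:
    # Same sequence as the stax.serial network: Dense, Relu, Dense, Relu, Dense(1), Sigmoid.
    h = relu(dense(x, params[0]))
    h = relu(dense(h, params[1]))
    return sigmoid(dense(h, params[2]))


rng = random.Random(0)
hidden = 8  # the notebook uses 1024; a small width keeps this sketch fast
params = [dense_init(rng, 2, hidden), dense_init(rng, hidden, hidden), dense_init(rng, hidden, 1)]

batch = [[0.0, 0.0], [1.0, -1.0], [2.5, 3.0]]  # like input shape (-1, 2): any number of 2D points
outputs = [forward(params, x) for x in batch]
print(outputs)  # one value per point, each strictly between 0 and 1
```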
From there, we compose our pipeline with the relevant ingredients. I'll point out things you can play with immediately:
p = neos.Pipeline(
    data=generate_data(rng=rng_state, num_points=10000),  # total number of points
    yield_kwargs=dict(
        bandwidth=1e-1,  # bandwidth of the KDE (lower = more like a real histogram)
        bins=jnp.linspace(0, 1, 5),  # 5 bin edges over [0, 1], i.e. 4 bins of the summary stat
    ),
    nn=nn,  # the nn we defined above
    # our chosen loss metric! You can compose your own loss
    # from a dict of metrics (see p.possible_metrics)
    loss=lambda x: x["CLs"],
    num_epochs=5,  # number of epochs
    batch_size=2000,  # number of points per batch
    plot_name="demo_nn_observable.png",  # save the final plot!
    animate=True,  # make cool animations!
    animation_name="demo_nn_observable.gif",  # save them!
    random_state=rng_state,
    yields_from_pars=nn_summary_stat,
    model_from_yields=make_model,
    init_pars=init_pars,
    first_epoch_callback=first_epoch,
    last_epoch_callback=last_epoch,
    per_epoch_callback=per_epoch,
    plot_setup=plot_setup,
)
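The bandwidth and bins options control how the network output becomes differentiable bin yields. One common way to do this, and an assumption here rather than a reading of the neos source, is a binned kernel density estimate: each event is smeared by a Gaussian of width bandwidth centred on its summary-statistic value, and the yield in a bin is the probability mass of that Gaussian falling between the bin edges. The helpers kde_hist and hard_hist below are hypothetical names for a plain-Python sketch of that idea, using the same 5 edges as jnp.linspace(0, 1, 5), which give the 4 bins seen in the 4-entry yield vectors of the training log.

```python
import math
from typing import List


def normal_cdf(x: float) -> float:
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def kde_hist(values: List[float], edges: List[float], bandwidth: float) -> List[float]:
    """Soft histogram: each value contributes the Gaussian mass inside each bin."""
    return [
        sum(normal_cdf((hi - v) / bandwidth) - normal_cdf((lo - v) / bandwidth) for v in values)
        for lo, hi in zip(edges[:-1], edges[1:])
    ]


def hard_hist(values: List[float], edges: List[float]) -> List[int]:
    """Ordinary (non-differentiable) histogram with half-open bins [lo, hi)."""
    return [sum(1 for v in values if lo <= v < hi) for lo, hi in zip(edges[:-1], edges[1:])]


edges = [0.0, 0.25, 0.5, 0.75, 1.0]  # same edges as jnp.linspace(0, 1, 5)
values = [0.1, 0.2, 0.3, 0.6, 0.7, 0.9]  # toy summary-statistic outputs

print(hard_hist(values, edges))  # [2, 1, 2, 1]
print(kde_hist(values, edges, bandwidth=1e-1))  # smooth, differentiable in the values
print(kde_hist(values, edges, bandwidth=1e-3))  # approaches the hard histogram
```

With a wide bandwidth, some mass from events near 0 or 1 leaks outside [0, 1], so the soft yields sum to slightly less than the number of events; a narrow bandwidth recovers ordinary counts but makes the gradients with respect to the event values vanish almost everywhere, which is the trade-off the bandwidth comment above hints at.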
Then we run! Each epoch takes around 15s on my local CPU, so expect something similar :)
You'll see some cool plots and animations, so it's worth it ;)
p.run()
epoch 4/5: 4 batches
batch 4/4 took 5.0345s.
batch loss: 9.36e-06
metrics evaluated on test set:
yields:
s = [1.84, 2.31, 5.04, 10.8]
b = [90.3, 6.25, 3, 0.4]
bup = [72.4, 13.1, 9.3, 5.15]
bdown = [72, 14.5, 9.85, 3.75]
CLs = 1.05e-06
mu_uncert = 0.0795
pull_width = 0.906
1-pull_width**2 = 0.00877
pull = -0.00898
loss = 1.05e-06
WARNING:matplotlib.animation:MovieWriter imagemagick unavailable; using Pillow instead.
(<matplotlib.animation.ArtistAnimation at 0x7f1f8c603c70>, {'CLs': [DeviceArray(2.35287632e-05, dtype=float64), DeviceArray(1.70566834e-06, dtype=float64), DeviceArray(1.11617354e-06, dtype=float64), DeviceArray(1.08438735e-06, dtype=float64), DeviceArray(1.97689571e-06, dtype=float64), DeviceArray(3.39126491e-06, dtype=float64), DeviceArray(3.25749298e-06, dtype=float64), DeviceArray(3.22669855e-06, dtype=float64), DeviceArray(3.14261578e-06, dtype=float64), DeviceArray(3.12066055e-06, dtype=float64), DeviceArray(2.90500077e-06, dtype=float64), DeviceArray(2.84339593e-06, dtype=float64), DeviceArray(2.31737151e-06, dtype=float64), DeviceArray(1.39411512e-06, dtype=float64), DeviceArray(1.07047366e-06, dtype=float64), DeviceArray(9.87367532e-07, dtype=float64), DeviceArray(9.80170437e-07, dtype=float64), DeviceArray(9.47542831e-07, dtype=float64), DeviceArray(1.07959525e-06, dtype=float64), DeviceArray(1.04620676e-06, dtype=float64)], 'mu_uncert': [DeviceArray(0.18044434, dtype=float64), DeviceArray(0.34672473, dtype=float64), DeviceArray(0.36839304, dtype=float64), DeviceArray(0.37000935, dtype=float64), DeviceArray(0.35839486, dtype=float64), DeviceArray(0.33573236, dtype=float64), DeviceArray(0.3013236, dtype=float64), DeviceArray(0.26116655, dtype=float64), DeviceArray(0.22139571, dtype=float64), DeviceArray(0.18284549, dtype=float64), DeviceArray(0.13734836, dtype=float64), DeviceArray(0.11403537, dtype=float64), DeviceArray(0.09754423, dtype=float64), DeviceArray(0.08339823, dtype=float64), DeviceArray(0.07825757, dtype=float64), DeviceArray(0.08114298, dtype=float64), DeviceArray(0.07252708, dtype=float64), DeviceArray(0.08343542, dtype=float64), DeviceArray(0.08273326, dtype=float64), DeviceArray(0.07947763, dtype=float64)], '1-pull_width**2': [DeviceArray(0.30723013, dtype=float64), DeviceArray(0.143446, dtype=float64), DeviceArray(0.11617264, dtype=float64), DeviceArray(0.12309279, dtype=float64), DeviceArray(0.08922027, dtype=float64), 
DeviceArray(0.07001359, dtype=float64), DeviceArray(0.04911868, dtype=float64), DeviceArray(0.03050092, dtype=float64), DeviceArray(0.00893807, dtype=float64), DeviceArray(0.0055946, dtype=float64), DeviceArray(0.00051842, dtype=float64), DeviceArray(0.00017573, dtype=float64), DeviceArray(4.74777157e-06, dtype=float64), DeviceArray(0.00173089, dtype=float64), DeviceArray(0.0077282, dtype=float64), DeviceArray(0.01848314, dtype=float64), DeviceArray(0.00245099, dtype=float64), DeviceArray(0.00624937, dtype=float64), DeviceArray(0.00488906, dtype=float64), DeviceArray(0.00877015, dtype=float64)], 'loss': [DeviceArray(0.06001607, dtype=float64), DeviceArray(0.00038053, dtype=float64), DeviceArray(2.41939008e-05, dtype=float64), DeviceArray(1.85273245e-05, dtype=float64), DeviceArray(2.43911782e-05, dtype=float64), DeviceArray(4.25606804e-05, dtype=float64), DeviceArray(4.49382958e-05, dtype=float64), DeviceArray(8.03035744e-05, dtype=float64), DeviceArray(7.42985366e-05, dtype=float64), DeviceArray(5.15046618e-05, dtype=float64), DeviceArray(6.88392294e-05, dtype=float64), DeviceArray(5.9455098e-05, dtype=float64), DeviceArray(4.92457398e-05, dtype=float64), DeviceArray(3.88366601e-05, dtype=float64), DeviceArray(3.16286162e-05, dtype=float64), DeviceArray(2.55299712e-05, dtype=float64), DeviceArray(1.25133961e-05, dtype=float64), DeviceArray(2.19693968e-05, dtype=float64), DeviceArray(1.48435742e-05, dtype=float64), DeviceArray(9.35913565e-06, dtype=float64)], 'test_loss': [DeviceArray(2.35287632e-05, dtype=float64), DeviceArray(1.70566834e-06, dtype=float64), DeviceArray(1.11617354e-06, dtype=float64), DeviceArray(1.08438735e-06, dtype=float64), DeviceArray(1.97689571e-06, dtype=float64), DeviceArray(3.39126491e-06, dtype=float64), DeviceArray(3.25749298e-06, dtype=float64), DeviceArray(3.22669855e-06, dtype=float64), DeviceArray(3.14261578e-06, dtype=float64), DeviceArray(3.12066055e-06, dtype=float64), DeviceArray(2.90500077e-06, dtype=float64), 
DeviceArray(2.84339593e-06, dtype=float64), DeviceArray(2.31737151e-06, dtype=float64), DeviceArray(1.39411512e-06, dtype=float64), DeviceArray(1.07047366e-06, dtype=float64), DeviceArray(9.87367532e-07, dtype=float64), DeviceArray(9.80170437e-07, dtype=float64), DeviceArray(9.47542831e-07, dtype=float64), DeviceArray(1.07959525e-06, dtype=float64), DeviceArray(1.04620676e-06, dtype=float64)], 'pull': [DeviceArray(0.00022655, dtype=float64), DeviceArray(0.05720803, dtype=float64), DeviceArray(0.06540602, dtype=float64), DeviceArray(0.06656146, dtype=float64), DeviceArray(0.06069352, dtype=float64), DeviceArray(0.05314889, dtype=float64), DeviceArray(0.04475, dtype=float64), DeviceArray(0.03678646, dtype=float64), DeviceArray(0.02953886, dtype=float64), DeviceArray(0.02164761, dtype=float64), DeviceArray(0.01385489, dtype=float64), DeviceArray(0.00788449, dtype=float64), DeviceArray(0.00100708, dtype=float64), DeviceArray(0.00174979, dtype=float64), DeviceArray(-0.00092322, dtype=float64), DeviceArray(-0.00235131, dtype=float64), DeviceArray(-0.00183691, dtype=float64), DeviceArray(-0.00774666, dtype=float64), DeviceArray(-0.00998592, dtype=float64), DeviceArray(-0.00897645, dtype=float64)], 'epoch_grid': [], 'pars': {'post-epoch-0': [(DeviceArray([[ 0.03461211, 0.08963372, 0.00074539, ..., -0.03061933, -0.03610808, -0.03041384], [ 0.0715813 , 0.02988483, 0.00811483, ..., -0.05734253, -0.05504322, -0.00157194]], dtype=float32), DeviceArray([ 0.01778431, -0.00402131, -0.00365581, ..., 0.0119588 , -0.00207934, -0.00749448], dtype=float32)), (), (DeviceArray([[ 0.00173731, 0.00430168, 0.00511208, ..., -0.03397341, -0.00374931, -0.01860504], [-0.02090159, 0.04028343, 0.05571346, ..., 0.04406765, -0.01242392, 0.06297275], [ 0.02269247, -0.00992602, -0.04954783, ..., 0.00636522, -0.03133326, -0.05437366], ..., [-0.00065288, 0.01556657, -0.0116964 , ..., 0.03372836, 0.00836964, -0.00823315], [-0.01722111, -0.02237464, -0.01559007, ..., 0.02027703, 0.06176555, 0.05073778], 
[-0.01017749, 0.02014277, -0.04629378, ..., -0.0658423 , -0.06948105, 0.0396019 ]], dtype=float32), DeviceArray([-0.00349353, -0.00035816, 0.01397473, ..., -0.00181315, -0.0241263 , 0.00324436], dtype=float32)), (), (DeviceArray([[-0.04651356], [ 0.03184702], [-0.05227218], ..., [ 0.00777215], [ 0.02719924], [-0.03107622]], dtype=float32), DeviceArray([0.00653545], dtype=float32)), ()], 'post-epoch-1': [(DeviceArray([[ 0.03300948, 0.09120078, -0.00025366, ..., -0.03095743, -0.03790518, -0.03191839], [ 0.07303267, 0.02846871, 0.00905005, ..., -0.05957199, -0.05401719, 0.00013198]], dtype=float32), DeviceArray([ 0.01737212, -0.00212188, -0.0045023 , ..., 0.01405519, -0.00257346, -0.00600346], dtype=float32)), (), (DeviceArray([[ 0.00255658, 0.00591386, 0.00367625, ..., -0.03399411, -0.00374931, -0.0193642 ], [-0.01916189, 0.03865325, 0.05736209, ..., 0.04246663, -0.01242392, 0.06470893], [ 0.02265702, -0.0084736 , -0.05114403, ..., 0.00723029, -0.03133326, -0.05437803], ..., [ 0.00102803, 0.01415336, -0.01009653, ..., 0.03248042, 0.00702963, -0.00649198], [-0.01559785, -0.02374438, -0.01402571, ..., 0.01906815, 0.06043033, 0.052417 ], [-0.01153067, 0.02161124, -0.04790483, ..., -0.06622753, -0.07081573, 0.03856871]], dtype=float32), DeviceArray([-0.00153802, -0.00188186, 0.01623154, ..., -0.00339514, -0.02548816, 0.0051833 ], dtype=float32)), (), (DeviceArray([[-0.04842605], [ 0.03042445], [-0.05111502], ..., [ 0.00637248], [ 0.02598962], [-0.03296407]], dtype=float32), DeviceArray([0.00388367], dtype=float32)), ()], 'post-epoch-2': [(DeviceArray([[ 3.2281853e-02, 9.2181161e-02, -2.1387808e-05, ..., -3.2303683e-02, -3.9687011e-02, -3.2599136e-02], [ 7.3500268e-02, 2.7700869e-02, 8.3445786e-03, ..., -6.1134338e-02, -5.3452242e-02, 1.0178870e-03]], dtype=float32), DeviceArray([ 0.01579907, -0.00055085, -0.00710869, ..., 0.01731742, -0.00261137, -0.00539406], dtype=float32)), (), (DeviceArray([[ 0.00335429, 0.00619762, 0.00285042, ..., -0.03448601, -0.00374931, 
-0.01966722], [-0.01807399, 0.0376292 , 0.05838328, ..., 0.04147416, -0.01242392, 0.06582875], [ 0.02252917, -0.00815426, -0.0521327 , ..., 0.00746765, -0.03133326, -0.0544034 ], ..., [ 0.00214813, 0.01318577, -0.00904041, ..., 0.03171881, 0.00621027, -0.00533723], [-0.01452005, -0.02468678, -0.01299596, ..., 0.01833109, 0.05961392, 0.05352702], [-0.01214267, 0.02214357, -0.04881328, ..., -0.06711706, -0.07162184, 0.03808565]], dtype=float32), DeviceArray([-0.00018717, -0.00326605, 0.01864611, ..., -0.00446187, -0.02632264, 0.00655048], dtype=float32)), (), (DeviceArray([[-0.04973336], [ 0.02930642], [-0.05045142], ..., [ 0.00543852], [ 0.02524832], [-0.03422985]], dtype=float32), DeviceArray([0.00060639], dtype=float32)), ()], 'post-epoch-3': [(DeviceArray([[ 0.03194332, 0.09291041, 0.00048266, ..., -0.03374135, -0.04123077, -0.03294481], [ 0.07342307, 0.0273406 , 0.0070903 , ..., -0.06178537, -0.05303269, 0.00146356]], dtype=float32), DeviceArray([ 0.01381254, 0.00085352, -0.00981413, ..., 0.02060601, -0.00247614, -0.00514659], dtype=float32)), (), (DeviceArray([[ 0.00437782, 0.0058378 , 0.00236915, ..., -0.03507716, -0.00374931, -0.01970195], [-0.01731944, 0.03690399, 0.05910962, ..., 0.04080106, -0.01242392, 0.06662571], [ 0.02241883, -0.00822063, -0.05276915, ..., 0.00751226, -0.03133326, -0.0544284 ], ..., [ 0.00288264, 0.01252365, -0.00833978, ..., 0.03121953, 0.0056723 , -0.00458386], [-0.01381773, -0.02533099, -0.01231753, ..., 0.01784828, 0.05907791, 0.05424607], [-0.01256784, 0.02227502, -0.04940869, ..., -0.06789716, -0.0721459 , 0.03773886]], dtype=float32), DeviceArray([ 0.0007551 , -0.00445799, 0.02079784, ..., -0.00520574, -0.02687146, 0.00751692], dtype=float32)), (), (DeviceArray([[-0.05060438], [ 0.02838052], [-0.05004761], ..., [ 0.00478811], [ 0.02476073], [-0.03508484]], dtype=float32), DeviceArray([-0.00260078], dtype=float32)), ()], 'post-epoch-4': [(DeviceArray([[ 0.03178716, 0.09343064, 0.00093899, ..., -0.03486153, -0.04236264, 
-0.03315259], [ 0.07320864, 0.02715596, 0.00601478, ..., -0.0619179 , -0.0526996 , 0.0017091 ]], dtype=float32), DeviceArray([ 0.01208899, 0.00197776, -0.01187852, ..., 0.02333199, -0.00233972, -0.00502234], dtype=float32)), (), (DeviceArray([[ 0.00539733, 0.00537654, 0.00206398, ..., -0.03553127, -0.00374931, -0.01962758], [-0.0167809 , 0.03640136, 0.05961957, ..., 0.04034355, -0.01242392, 0.0672012 ], [ 0.02234982, -0.0083178 , -0.0531976 , ..., 0.00752538, -0.03133326, -0.05444454], ..., [ 0.00336435, 0.01207753, -0.00786968, ..., 0.03088279, 0.00530899, -0.00408948], [-0.01335931, -0.02576389, -0.01186466, ..., 0.01752286, 0.05871592, 0.05471517], [-0.01291023, 0.02227196, -0.04981849, ..., -0.06842809, -0.07249686, 0.03746387]], dtype=float32), DeviceArray([ 0.00140617, -0.00536992, 0.02245046, ..., -0.00571594, -0.02724266, 0.00819512], dtype=float32)), (), (DeviceArray([[-0.05118783], [ 0.02766497], [-0.04979014], ..., [ 0.00434128], [ 0.02443091], [-0.03566821]], dtype=float32), DeviceArray([-0.00523573], dtype=float32)), ()]}, 'yields': [array([ 1.84, 2.31, 5.04, 10.81]), array([90.35, 6.25, 3. , 0.4 ]), array([72.4 , 13.15, 9.3 , 5.15]), array([71.95, 14.45, 9.85, 3.75])], 'pull_width': DeviceArray(0.90635093, dtype=float64)})
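The loss passed to the pipeline is a function of a dict of metrics, and the final log shows which metrics are available on the test set: CLs, mu_uncert, pull_width, 1-pull_width**2 and pull. As a sketch, here is how one could compose a different loss from such a dict. The metric values are copied from the last epoch's log; combined_loss and its weight of 0.1 are hypothetical choices, not something this demo uses.

```python
from typing import Callable, Dict

Metrics = Dict[str, float]

# Final test-set metrics copied from the "epoch 4/5" log above.
final_metrics: Metrics = {
    "CLs": 1.05e-06,
    "mu_uncert": 0.0795,
    "pull_width": 0.906,
    "1-pull_width**2": 0.00877,
    "pull": -0.00898,
}

# The loss used in this demo: just CLs.
cls_loss: Callable[[Metrics], float] = lambda m: m["CLs"]


def combined_loss(m: Metrics, weight: float = 0.1) -> float:
    """Hypothetical alternative: CLs plus a penalty on the signal-strength uncertainty.

    The weight is arbitrary and would need tuning; it sets how much a larger
    mu_uncert is penalised relative to CLs.
    """
    return m["CLs"] + weight * m["mu_uncert"]


print(cls_loss(final_metrics))  # 1.05e-06, matching "loss = 1.05e-06" in the log
print(combined_loss(final_metrics))
```

To try something like this, you would pass loss=combined_loss to neos.Pipeline in place of the CLs lambda, assuming the pipeline hands the loss the same metrics dict that it logs. Note that CLs here is around 1e-06 while mu_uncert is around 0.08, so with this weight the uncertainty term would dominate the gradient.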