# hide
from neos.models import *
from neos.makers import *
from neos.transforms import *
from neos.fit import *
from neos.infer import *
from neos.smooth import *

# bunch of imports:
import time
from functools import partial

import jax
# note: in newer JAX releases these modules live under jax.example_libraries
import jax.experimental.optimizers as optimizers
import jax.experimental.stax as stax
import jax.random
from jax.random import PRNGKey
import numpy as np

import pyhf

pyhf.set_backend("jax")
pyhf.default_backend = pyhf.tensor.jax_backend(precision="64b")

from neos import data, infer, makers

rng = PRNGKey(22)

# simple feed-forward network ending in a sigmoid, so outputs live in [0, 1]
init_random_params, predict = stax.serial(
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid,
)

# data generator
data_gen = data.generate_blobs(rng, blobs=4)

# histogram maker (KDE-smoothed histograms of the network output)
hist_maker = makers.hists_from_nn(data_gen, predict, method="kde")

# statistical model maker
model_maker = makers.histosys_model_from_hists(hist_maker)

# CLs value getter
get_cls = infer.expected_CLs(model_maker, solver_kwargs=dict(pdf_transform=True))

bins = np.linspace(0, 1, 4)  # three bins in the range [0, 1]
bandwidth = 0.27  # smoothing parameter for the KDE
get_loss = partial(get_cls, hyperparams=dict(bins=bins, bandwidth=bandwidth))


def loss(params, test_mu):
    # the first returned element is the expected CLs value, which we minimise
    return get_loss(params, test_mu)[0]


# init weights
_, network = init_random_params(jax.random.PRNGKey(2), (-1, 2))

# init optimizer
opt_init, opt_update, opt_params = optimizers.adam(1e-3)


# define train loop
def train_network(N):
    _, network = init_random_params(jax.random.PRNGKey(1), (-1, 2))
    state = opt_init(network)
    losses = []

    # parameter update function
    def update_and_value(i, opt_state, mu):
        net = opt_params(opt_state)
        value, grad = jax.value_and_grad(loss)(net, mu)
        # update using the state passed in, not the enclosing loop variable
        return opt_update(i, grad, opt_state), value, net

    for i in range(N):
        start_time = time.time()
        state, value, network = update_and_value(i, state, 1.0)
        epoch_time = time.time() - start_time
        losses.append(value)
        metrics = {"loss": losses}
        yield network, metrics, epoch_time


maxN = 10  # make me bigger for better results (*nearly* true ;])

for i, (network, metrics, epoch_time) in enumerate(train_network(maxN)):
    print(f"epoch {i}:", f'CLs = {metrics["loss"][-1]:.5f}, took {epoch_time:.4f}s')
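
# A minimal sketch (not part of the original snippet) of how one might inspect the
# trained network after the loop finishes: stax's `predict(params, inputs)` can be
# evaluated on a grid over the 2D input space used at init time (shape (-1, 2)).
# The grid range and variable names below are illustrative assumptions.
grid = np.stack(
    np.meshgrid(np.linspace(-5, 5, 100), np.linspace(-5, 5, 100)), axis=-1
).reshape(-1, 2)
scores = predict(network, grid)  # sigmoid outputs in [0, 1], shape (10000, 1)
print(scores.min(), scores.max())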