# default_exp data
Helper module to easily generate example data.
#export
import jax
import jax.numpy as jnp
def generate_blobs(
    rng,
    blobs,
    NMC=500,
    sig_mean=jnp.asarray([-1, 1]),
    bup_mean=jnp.asarray([2.5, 2]),
    bdown_mean=jnp.asarray([-2.5, -1.5]),
    b_mean=jnp.asarray([1, -1]),
):
    """
    Return a callable that generates a set of 2D normally distributed
    blobs, corresponding to signal, background, and background
    uncertainty modes.

    Every blob is sampled with an identity covariance matrix and with its
    own subkey split from `rng`, so the blobs are statistically independent
    of each other. The returned callable is deterministic: calling it
    repeatedly yields the same samples, because `rng` is fixed at creation.

    Args:
        rng: jax PRNG key (random seed).
        blobs: Number of blobs to generate (3 or 4).
        NMC: Number of 'monte carlo' samples to generate per blob.
        sig_mean: jax array of the mean of the signal distribution.
        bup_mean: jax array of the mean of the 'up' background distribution.
        bdown_mean: jax array of the mean of the 'down' background
            distribution.
        b_mean: jax array of the mean of the nominal background
            distribution (only used when `blobs == 4`).

    Returns:
        A zero-argument callable returning a tuple of arrays, one per blob,
        each with `NMC` rows: `(sig, bkg_up, bkg_down)` for 3 blobs, or
        `(sig, bkg_nom, bkg_up, bkg_down)` for 4 blobs.

    Raises:
        AssertionError: If `blobs` is neither 3 nor 4.
    """
    # The order of the means fixes the order of the returned blobs.
    if blobs == 3:
        means = (sig_mean, bup_mean, bdown_mean)
    elif blobs == 4:
        means = (sig_mean, b_mean, bup_mean, bdown_mean)
    else:
        # Raised explicitly (rather than via `assert False`) so the check
        # survives `python -O` while keeping the same exception type.
        raise AssertionError(
            f"Unsupported number of blobs: {blobs}"
            " (only using 3 or 4 blobs for these examples)."
        )

    cov = jnp.asarray([[1, 0], [0, 1]])

    def gen_blobs():
        # One independent subkey per blob: reusing a single key would give
        # every blob the identical noise realisation, shifted by its mean.
        keys = jax.random.split(rng, len(means))
        return tuple(
            jax.random.multivariate_normal(key, mean, cov, shape=(NMC,))
            for key, mean in zip(keys, means)
        )

    return gen_blobs
import jax
from jax.random import PRNGKey
import neos
from neos.data import generate_blobs
# Build the generator callable for 3 blobs (signal, background 'up',
# background 'down') with 100 samples each, then call it to draw the data.
data_gen = generate_blobs(rng=PRNGKey(2), blobs=3, NMC=100)
data = data_gen()
# Show the shape of each returned blob.
[x.shape for x in data]
[(100, 2), (100, 2), (100, 2)]
Same for 4 blobs, but half the number of samples:
# Build the generator callable for 4 blobs (signal, nominal background,
# background 'up', background 'down') with 50 samples each, then call it
# to draw the data.
data_gen = generate_blobs(rng=PRNGKey(2), blobs=4, NMC=50)
data = data_gen()
# Show the shape of each returned blob.
[x.shape for x in data]
[(50, 2), (50, 2), (50, 2), (50, 2)]