Let's look at a problem that resembles machine learning just a little: curve fitting for unbinned data. We are going to ignore the actual minimizer, and instead just compute the negative log likelihood (nll).
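For reference, here is the quantity we'll be computing: the nll of a two-component Gaussian mixture with a shared mean, where $f_0$ is the fraction of events in the first component (this matches the code below):

$$
\mathrm{nll}(f_0, \mu, \sigma, \sigma_2) = -\sum_i \ln\!\left[\, f_0\,\mathcal{N}(x_i;\mu,\sigma) + (1 - f_0)\,\mathcal{N}(x_i;\mu,\sigma_2) \,\right],
\qquad
\mathcal{N}(x;\mu,\sigma) = \frac{1}{\sqrt{2\pi\sigma^2}}\, e^{-(x-\mu)^2/(2\sigma^2)}
$$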
# import jax
# jax.config.update("jax_enable_x64", True)
import numpy as np
np.random.seed(42)
# Two million events from a mixture of two Gaussians with a shared mean
dist = np.hstack(
    [
        np.random.normal(loc=1, scale=2.0, size=1_000_000),
        np.random.normal(loc=1, scale=0.5, size=1_000_000),
    ]
)
Let's start with NumPy, just to show how it would be done:
def gaussian(x, μ, σ):
    # Normalized Gaussian PDF
    return 1 / np.sqrt(2 * np.pi * σ**2) * np.exp(-((x - μ) ** 2) / (2 * σ**2))

def add(x, f_0, μ, σ, σ2):
    # Two-component mixture: fraction f_0 in the first Gaussian
    return f_0 * gaussian(x, μ, σ) + (1 - f_0) * gaussian(x, μ, σ2)

def nll(x, f_0, μ, σ, σ2):
    # Negative log likelihood summed over all events
    return -np.sum(np.log(add(x, f_0, μ, σ, σ2)))
%%time
nll(dist, *np.random.rand(4))
%%timeit
nll(dist, *np.random.rand(4))
JAX is a library from Google. It can target a wide variety of backends (CPU, GPU, and TPU), it can JIT compile, and it can take gradients. It is very powerful, and also rather tricky, since it does quite a few things differently from NumPy. First, let's try using it:
import jax
import jax.numpy as jnp
Now we'll just replace `np` with `jnp` everywhere in the above code, to produce:
def gaussian(x, μ, σ):
    return 1 / jnp.sqrt(2 * jnp.pi * σ**2) * jnp.exp(-((x - μ) ** 2) / (2 * σ**2))

def add(x, f_0, μ, σ, σ2):
    return f_0 * gaussian(x, μ, σ) + (1 - f_0) * gaussian(x, μ, σ2)

def nll(x, f_0, μ, σ, σ2):
    return -jnp.sum(jnp.log(add(x, f_0, μ, σ, σ2)))
Now we need just one more step: we need JAX arrays instead of NumPy arrays:
d_dist = jnp.asarray(dist)
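One quick check before moving on: JAX defaults to 32-bit floats, so this conversion silently downcasts our float64 NumPy array. We'll come back to this at the end.

print(dist.dtype)    # float64
print(d_dist.dtype)  # float32 (unless jax_enable_x64 is set)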
There's still one more step (the JIT), but let's check the timing first:
%%time
nll(d_dist, *np.random.rand(4)).block_until_ready()
%%timeit
nll(d_dist, *np.random.rand(4)).block_until_ready()
We are probably already seeing a nice speedup here. File that away; we'll explain it later. For now, let's move on.
Now we can JIT our function. Unlike Numba, we don't need to decorate every helper function; we just pass the top-level function in, and JAX traces through everything it calls.
nll_jit = jax.jit(nll)
Now, the first time we call it, JAX will "trace" the function and produce the XLA code for it. Like other tracers, it can't handle non-vectorized, data-dependent Python control flow.
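To make that concrete, here is a minimal sketch (`abs_bad` and `abs_good` are illustrative names, not part of our fit): a Python `if` on a traced value fails at trace time, while the vectorized equivalent goes through:

@jax.jit
def abs_bad(x):
    if x > 0:  # fails under tracing: x is an abstract tracer, not a concrete value
        return x
    return -x

@jax.jit
def abs_good(x):
    return jnp.where(x > 0, x, -x)  # vectorized branching traces fine

abs_good(-3.0)  # works
# abs_bad(-3.0) raises a tracer/concretization error

Our `nll` is fully vectorized, so it traces without trouble: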
%%time
nll_jit(d_dist, *np.random.rand(4)).block_until_ready()
Now that it's primed, let's measure:
%%timeit
nll_jit(d_dist, *np.random.rand(4)).block_until_ready()
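We said earlier that JAX can take gradients. As a minimal sketch of what a real minimizer would consume (the parameter values below are arbitrary), `jax.grad` differentiates the scalar `nll` with respect to the four fit parameters, and it composes with `jax.jit`:

# Differentiate with respect to the four fit parameters, skipping the
# data array in argument position 0, then JIT the resulting function.
grad_nll = jax.jit(jax.grad(nll, argnums=(1, 2, 3, 4)))

grad_nll(d_dist, 0.5, 1.0, 2.0, 0.5)  # tuple of four scalar gradients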
This is very nice, but there is a caveat: all of this ran in JAX's default 32-bit mode. Uncomment the code at the top, restart the kernel, and compare the timings again.