A JAX-friendly, autodifferentiable, Python-only implementation of a subset of correctionlib correction evaluations.
If you are not familiar with these terms: for the purposes of this presentation, autodifferentiation means automatically turning a program that computes f(x) into one that also computes its derivative:

f(x) -> df/dx(x)
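A minimal JAX sketch of the idea (the function and values here are made up, purely for illustration):

import jax

def f(x):
    return x * x  # f(x) = x^2

df = jax.grad(f)  # df is a new function that computes df/dx(x) = 2x
print(df(3.0))    # 6.0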
As an exercise, I tried to follow recommendations from https://learn.scientific-python.org as closely as possible when working on the project.
To enable autodifferentiation of full HEP analysis pipelines, we need to make each component of the pipeline autodifferentiable.
This project makes one more piece of the pipeline autodifferentiable, namely correction evaluation.
pip install correctionlib-gradients
import jax
from correctionlib import schemav2
from correctionlib_gradients import CorrectionWithGradient
Given a correctionlib schema (which can be deserialized from JSON using correctionlib)...
formula_schema = schemav2.Correction(
    name="RMS of x and y",
    version=2,
    inputs=[
        schemav2.Variable(name="x", type="real"),
        schemav2.Variable(name="y", type="real"),
    ],
    output=schemav2.Variable(name="a scale", type="real"),
    data=schemav2.Formula(
        nodetype="formula",
        expression="sqrt((x * x + pow(y, 2)) / 2)",
        parser="TFormula",
        variables=["x", "y"],
    ),
)
...construct a CorrectionWithGradient object and you are done:
corr = CorrectionWithGradient(formula_schema)
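If the correction lives in a JSON file instead, a sketch like the following should work (the file name and correction name are made up, and building the pydantic model from the parsed JSON is one way to deserialize, not necessarily the canonical one):

import json

# hypothetical file containing a schemav2 CorrectionSet document
with open("my_corrections.json") as f:
    cset = schemav2.CorrectionSet(**json.load(f))

# pick the correction of interest out of the set by name
[rms_schema] = [c for c in cset.corrections if c.name == "RMS of x and y"]
corr = CorrectionWithGradient(rms_schema)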
You can now use corr.evaluate as a JAX-friendly, auto-differentiable function that composes with any other JAX-friendly operation:
df = jax.value_and_grad(corr.evaluate, argnums=[0, 1])
value, grad = df(0.0, 1.0)
print(f"{value = }\ndf/dx = {grad[0]}\ndf/dy = {grad[1]}")
value = Array(0.70710677, dtype=float32, weak_type=True)
df/dx = 0.0
df/dy = 0.7071067690849304
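Because evaluate is JAX-traceable, its output can also be fed into further JAX operations and differentiated end to end (the toy expression below is made up, just to show composition):

import jax.numpy as jnp

def toy_loss(x, y):
    # compose the correction with other JAX operations
    return jnp.tanh(corr.evaluate(x, y)) ** 2

print(jax.grad(toy_loss, argnums=0)(0.5, 1.0))  # d(toy_loss)/dx at (0.5, 1.0)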
jax.jit works too (for Formulas; support for Binning is work in progress):
jitted_df = jax.jit(jax.value_and_grad(corr.evaluate, argnums=[0, 1]))
value, grad = jitted_df(0.0, 1.0)
print(f"{value = }\ndf/dx = {grad[0]}\ndf/dy = {grad[1]}")
value = Array(0.70710677, dtype=float32, weak_type=True)
df/dx = 0.0
df/dy = 0.7071067690849304
jax.vmap can be used to vectorize the evaluation over multiple rows of inputs:
xs = jax.numpy.array([0.0, 1.0])
ys = jax.numpy.array([1.0, 0.0])
vec_df = jax.vmap(jax.jit(jax.value_and_grad(corr.evaluate, argnums=[0, 1])))
value, grad = vec_df(xs, ys)
print(f"{value = }\ndf/dx = {grad[0]}\ndf/dy = {grad[1]}")
value = Array([0.70710677, 0.70710677], dtype=float32)
df/dx = [0. 0.70710677]
df/dy = [0.70710677 0. ]
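jax.vmap's usual options apply as well; for example, in_axes can vectorize over x while broadcasting a single y value (standard JAX usage, nothing library-specific):

vec_over_x = jax.vmap(corr.evaluate, in_axes=(0, None))
print(vec_over_x(xs, 1.0))  # evaluates the correction at (0.0, 1.0) and (1.0, 1.0)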
Currently the following corrections from correctionlib.schemav2 are supported:
- Formula, including parametrical formulas
- Binning with uniform or non-uniform bin edges and flow="clamp"; bin contents can be either Formula or FormulaRef (see the sketch below)
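For concreteness, here is a sketch of a Binning correction that should fall within the supported subset; all names, edges and expressions are made up for illustration:

binning_schema = schemav2.Correction(
    name="binned scale",
    version=2,
    inputs=[schemav2.Variable(name="pt", type="real")],
    output=schemav2.Variable(name="a scale", type="real"),
    data=schemav2.Binning(
        nodetype="binning",
        input="pt",
        edges=[0.0, 20.0, 50.0],  # non-uniform bin edges
        content=[
            # one Formula per bin; in TFormula expressions, x refers to the
            # first variable listed in `variables` (here: pt)
            schemav2.Formula(
                nodetype="formula", expression="1.0 + 0.01 * x",
                parser="TFormula", variables=["pt"],
            ),
            schemav2.Formula(
                nodetype="formula", expression="1.2",
                parser="TFormula", variables=["pt"],
            ),
        ],
        flow="clamp",
    ),
)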
Open questions:
- MultiBinning?
- Binning of Formulas (i.e. formulas defined piece-wise)? jax.grad works with them, but not jax.jit, jax.vmap etc. (Formulas inside a Binning are not JAX-traceable). Can we do better?
- What is the "right" relaxation?
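One purely illustrative direction for such a relaxation (emphatically not what the library currently does): replace the hard bin boundaries with sigmoid-smoothed transitions, so that the binned value becomes differentiable in its input:

import jax
import jax.numpy as jnp

def soft_binning(x, edges, contents, temperature=5.0):
    # softly select the bin containing x: each bin's weight is the product
    # of a rising sigmoid at its left edge and a falling one at its right edge
    edges = jnp.asarray(edges)
    contents = jnp.asarray(contents)
    left = jax.nn.sigmoid((x - edges[:-1]) / temperature)
    right = jax.nn.sigmoid((edges[1:] - x) / temperature)
    weights = left * right
    return jnp.sum(weights * contents) / jnp.sum(weights)

# unlike hard binning, the gradient is nonzero near bin boundaries
print(jax.grad(soft_binning)(25.0, [0.0, 20.0, 50.0], [1.1, 0.9]))

The temperature parameter trades off smoothness against fidelity to the original hard binning.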