#!/usr/bin/env python
# coding: utf-8

# # correctionlib-gradients @ AGC Demo Day 30/11/2023
#
# ### Enrico Guiraud, with help from Nick Smith and Lukas Heinrich
#
# ![image.png](attachment:2aae49a8-aafb-4212-b2d2-ab45603a896b.png)

# ## What
#
# A [JAX](https://jax.readthedocs.io)-friendly, autodifferentiable, Python-only implementation of a subset of [correctionlib](https://github.com/cms-nanoAOD/correctionlib) correction evaluations.
#
# If you don't know what these terms mean, for the purposes of this presentation:
#
# - automatic differentiation (aka autodiff) is a method to compute gradients of numerical functions: `f(x) -> df/dx(x)`
# - [JAX](https://jax.readthedocs.io) is (among other things) a Python tool that makes functions operating on arrays autodifferentiable
# - [correctionlib](https://github.com/cms-nanoAOD/correctionlib) is a JSON-based schema for HEP corrections, as well as the reference implementation (by Nick Smith) for evaluating those corrections
#
# As an exercise, I tried to follow the recommendations from https://learn.scientific-python.org as closely as possible while working on the project.

# ## Why
#
# To enable autodifferentiation of full HEP analysis pipelines, we need to make each component of the pipeline autodifferentiable.
#
# This project makes one more piece of the pipeline autodifferentiable, namely correction evaluation.

# ## Installation
#
# ```console
# pip install correctionlib-gradients
# ```

# ## Example usage

# In[1]:

import jax
from correctionlib import schemav2

from correctionlib_gradients import CorrectionWithGradient

# Given a correctionlib schema (which can be deserialized from JSON using correctionlib)...

# In[2]:

formula_schema = schemav2.Correction(
    name="RMS of x and y",
    version=2,
    inputs=[
        schemav2.Variable(name="x", type="real"),
        schemav2.Variable(name="y", type="real"),
    ],
    output=schemav2.Variable(name="a scale", type="real"),
    data=schemav2.Formula(
        nodetype="formula",
        expression="sqrt((x * x + pow(y, 2)) / 2)",
        parser="TFormula",
        variables=["x", "y"],
    ),
)

# ...construct a `CorrectionWithGradient` object and you are done:

# In[3]:

corr = CorrectionWithGradient(formula_schema)

# You can use `corr.evaluate` as a JAX-friendly, autodifferentiable function that composes with any other JAX-friendly operation:

# In[4]:

df = jax.value_and_grad(corr.evaluate, argnums=[0, 1])
value, grad = df(0.0, 1.0)
print(f"{value = }\ndf/dx = {grad[0]}\ndf/dy = {grad[1]}")

# `jax.jit` works too (for `Formula`s; work in progress for `Binning`):

# In[5]:

jitted_df = jax.jit(jax.value_and_grad(corr.evaluate, argnums=[0, 1]))
value, grad = jitted_df(0.0, 1.0)
print(f"{value = }\ndf/dx = {grad[0]}\ndf/dy = {grad[1]}")

# `jax.vmap` can be used to vectorize the evaluation over multiple rows of inputs:

# In[6]:

xs = jax.numpy.array([0.0, 1.0])
ys = jax.numpy.array([1.0, 0.0])
vec_df = jax.vmap(jax.jit(jax.value_and_grad(corr.evaluate, argnums=[0, 1])))
value, grad = vec_df(xs, ys)
print(f"{value = }\ndf/dx = {grad[0]}\ndf/dy = {grad[1]}")

# ## Supported types of corrections
#
# Currently the following corrections from `correctionlib.schemav2` are supported:
#
# - `Formula`, including parametrized formulas
# - `Binning` with uniform or non-uniform bin edges and `flow="clamp"` (a minimal sketch follows this list); bin contents can be either:
#   - all scalar values
#   - all `Formula` or `FormulaRef`
# - scalar constants
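# For concreteness, here is a minimal sketch of a `Binning` correction within this supported subset.
# This cell is not part of the original demo: the correction name, bin edges and bin contents are
# made up purely for illustration.

# In[ ]:

binning_schema = schemav2.Correction(
    name="a made-up scale factor binned in x",  # illustrative only
    version=2,
    inputs=[schemav2.Variable(name="x", type="real")],
    output=schemav2.Variable(name="a scale", type="real"),
    data=schemav2.Binning(
        nodetype="binning",
        input="x",
        edges=[0.0, 30.0, 60.0, 120.0],  # non-uniform bin edges
        content=[1.1, 1.05, 1.0],  # all-scalar bin contents (could also be all Formula/FormulaRef)
        flow="clamp",  # out-of-range inputs are clamped to the closest bin
    ),
)

binned_corr = CorrectionWithGradient(binning_schema)

# As above, `binned_corr.evaluate` can be differentiated with JAX; for scalar bin contents the
# gradient comes from a differentiable approximation of the binned values (a cubic-spline
# approximation in the current implementation, see the open questions below) rather than from the
# piece-wise constant look-up itself. As noted above, `jax.jit` is still work in progress for `Binning`.
value, grad = jax.value_and_grad(binned_corr.evaluate)(45.0)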
# ## Open questions
#
# - what other types of corrections definitely need support? `MultiBinning`?
# - what is the "right" differentiable relaxation of a histogram? [more discussion in #gradhep](https://iris-hep.slack.com/archives/C0155BGPGE4/p1699120426133849) (a toy sketch of the problem is at the end of this document)
# - what is the "right" way to differentiate a composition of `Binning` and `Formula`s (i.e. formulas defined piece-wise)?
# - several operations are not JAX-traceable (`jax.grad` works with them, but `jax.jit`, `jax.vmap` etc. do not); can we do better?
#   - SciPy's cubic splines, which we use as a differentiable approximation of a 1D histogram profile, are not JAX-traceable
#   - the bin look-ups we perform when a `Formula` sits inside a `Binning` are not JAX-traceable
# - further development: I am moving on soon. I plan to add a few obvious improvements, but I won't be around to develop this in the years to come.

# ### An example histogram (blue) with one possible differentiable relaxation (orange)
#
# ![image.png](attachment:6641f60c-1892-4c2b-84ae-856906c996f2.png)
#
# What is the "right" relaxation?

# ### An example compound histogram (some bins contain Formulas)
#
# ![image.png](attachment:a60ab5fe-632b-4de6-a001-7076f21cd9bd.png)
#
# What is the "right" relaxation?

# ## Bonus: what about C++?
#
# - autodifferentiable analyses are mostly a Python R&D topic; there is not much interest in end-to-end differentiable analyses in C++ in the community
# - via Python+JAX we can make fast progress, learn about the problem, the solution space and the technical hurdles, and then go for a C++ implementation in the future if needed
# - at the time of writing I do not know of a C++ library that can autodifferentiate a compute-graph evaluation the way JAX does in this package (and implementing such a library is a project of larger scope than we feel like tackling at this point)
# - I think the best option for C++ is a code generator that produces code evaluating a given correction, whose output is then passed through [Enzyme](https://enzyme.mit.edu) or [Clad](https://github.com/vgvassilev/clad)
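# ### Appendix: a toy sketch of why a relaxation is needed
#
# This cell is not part of the package and uses made-up bin edges and contents; it only illustrates
# the problem behind the "right relaxation" question above. A hard bin look-up is piece-wise
# constant, so its gradient with respect to the input is zero (almost) everywhere; any smooth
# stand-in for the histogram (here a simple linear interpolation between bin centers, chosen just
# as an example) carries useful gradient information instead.

# In[ ]:

import jax
import jax.numpy as jnp

edges = jnp.array([0.0, 1.0, 2.0, 3.0])
contents = jnp.array([1.1, 1.05, 1.0])
centers = (edges[:-1] + edges[1:]) / 2

def hard_lookup(x):
    # piece-wise constant histogram evaluation, clamped at the outermost bins
    i = jnp.clip(jnp.searchsorted(edges, x) - 1, 0, contents.size - 1)
    return contents[i]

def smooth_lookup(x):
    # one possible relaxation: linear interpolation between bin centers
    return jnp.interp(x, centers, contents)

print(jax.grad(hard_lookup)(1.3))    # 0.0: the histogram itself gives no gradient signal
print(jax.grad(smooth_lookup)(1.3))  # non-zero: the relaxation is differentiable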