import pyhf
pyhf.set_backend('jax')
import jax
import jaxlib
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
Welcome to this tutorial on automatic differentiation. Automatic differentiation is a method to compute exact derivatives of functions implemented as programs. It is widely applicable and famously used in many machine learning optimization problems. E.g. neural networks, which are parametrized by weights $\text{NN}(\text{weights})$, are trained by (stochastic) gradient descent to find the minimum of the loss function $L$ where
$$\text{weights}_\text{opt} = \text{argmin}_\text{weights} L(\text{weights}) \hspace{1cm} \nabla L(\text{weights}) = 0$$
This means that efficient algorithms to compute derivatives are crucial.
Aside from ML, many other use cases require gradients: standard statistical analysis in HEP (fitting, hypothesis testing, ...) requires gradients, uncertainty propagation (e.g. of track parameters) uses gradients, etc.
Before diving into automatic differentiation, let's review how we might otherwise compute derivatives.
A common approach to approximate gradients of a black-box function is to evaluate it at close-by points $x$ and $x+\Delta x$ and compute
$\frac{\partial f}{\partial x} \approx \frac{f(x+\Delta x) - f(x)}{\Delta x}$, which is accurate if $\Delta x$ is sufficiently small
def black_box_func(x):
    return x**3
def true_gradient_func(x):
    return 3*x**2
def plot_gradients(nsteps,title):
    xi = np.linspace(-5,5,nsteps)
    yi = black_box_func(xi)
    approx_gradient = np.gradient(yi,xi)
    true_gradient = true_gradient_func(xi)
    plt.plot(xi,yi, label = 'black-box func')
    plt.scatter(xi,yi)
    plt.plot(xi,approx_gradient, label = 'finite diff grad')
    plt.scatter(xi,approx_gradient)
    plt.plot(xi,true_gradient, label = 'true grad')
    plt.scatter(xi,true_gradient)
    plt.legend()
    plt.title(title)
    plt.show()
plot_gradients(11, title = 'it is pretty bad if Δx is too large')
plot_gradients(41, title = 'it gets better at the cost of many evaluations')
While only approximate, finite differences is *simple*. I don't need to know anything about the function beyond having the ability to *evaluate* it.
This way I can compute gradients of functions encoded as a computer program, and it works in any programming language.
For multivariate (possibly vector-valued) functions $\vec{f}(\vec{x}) = f_i(x_1,x_2,\dots,x_n)$ one needs to compute a finite difference for each partial derivative $\frac{\partial f_i}{\partial x_k}$ in order to get the full Jacobian / total derivative $df_i = J_{ik} dx_k$ with $J_{ik} = \frac{\partial f_i}{\partial x_k}$
In high dimensions, the number of required evaluations explodes!
Finite differences, in short: simple and language-agnostic, but only approximate and expensive in high dimensions.
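To make the evaluation count concrete, here is a minimal sketch of a forward-difference Jacobian that also counts how often the function is called. The helper name `finite_difference_jacobian`, the step size and the toy map `example_map` are assumptions made for illustration; they are not part of the tutorial.

```python
import numpy as np

def finite_difference_jacobian(f, x, dx=1e-6):
    """Forward-difference Jacobian of f at x; returns (J, number of calls to f)."""
    x = np.asarray(x, dtype=float)
    f0 = np.asarray(f(x), dtype=float)   # one baseline evaluation
    n_calls = 1
    J = np.zeros((f0.size, x.size))
    for k in range(x.size):              # one extra evaluation per input dimension
        x_step = x.copy()
        x_step[k] += dx
        J[:, k] = (np.asarray(f(x_step), dtype=float) - f0) / dx
        n_calls += 1
    return J, n_calls

def example_map(x):
    # hypothetical map R^3 -> R^2, chosen only for illustration
    return np.array([x[0] * x[1], x[1] + x[2] ** 2])

J, n_calls = finite_difference_jacobian(example_map, [1.0, 2.0, 3.0])
print(J)        # approximately [[2, 1, 0], [0, 1, 6]]
print(n_calls)  # 4 = 1 + number of inputs
```

For a function of 10000 inputs the same loop would need 10001 evaluations, whatever the number of outputs.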
Computer Algebra Systems (CAS), such as Mathematica (or sympy), can manipulate functional expressions and know about differentiation rules (and many other things).
If the function / the program which we want to differentiate is available as such an expression, then symbolic differentiation can produce exact gradients
import sympy
def function(x):
    return x**3
def true_deriv(x):
    return 3*x**2
symbolic_x = sympy.symbols('x')
symbolic_func = function(symbolic_x)
symbolic_func
Using `lambdify` we can turn it into a normal Python function we can evaluate
xi = np.linspace(-5,5,11)
yi = sympy.lambdify(symbolic_x,symbolic_func)(xi)
plt.plot(xi,yi)
plt.scatter(xi,yi)
`symbolic_func` is now an expression which we can differentiate symbolically
symbolic_deriv = symbolic_func.diff(symbolic_x)
symbolic_deriv
def plot_symbolic(nsteps,title):
    xi = np.linspace(-5,5,nsteps)
    yi = sympy.lambdify(symbolic_x,symbolic_func)(xi)
    plt.scatter(xi,yi)
    plt.plot(xi,yi, label = 'function')
    yi = true_deriv(xi)
    plt.plot(xi,yi)
    plt.scatter(xi,yi, label = 'true deriv')
    yi = sympy.lambdify(symbolic_x,symbolic_deriv)(xi)
    plt.plot(xi,yi)
    plt.scatter(xi,yi, label = 'symbolic deriv')
    plt.legend()
    plt.title(title)
    plt.show()
plot_symbolic(11,title = 'the symbolic derivative is always exact')
plot_symbolic(4, title = 'it does not matter where/how often you evaluate it')
We can even handle function compositions
def f1(x):
    #standard operations are overloaded
    return x**2
def f2(x):
    #note here we use a special cos function from sympy
    #instead of e.g. np.cos or math.cos
    return sympy.cos(x)
composition = f2(f1(symbolic_x))
composition
composition.diff(symbolic_x)
Since `sympy` knows about the chain rule, it can differentiate accordingly
This looks great! We get exact derivatives. However, there are drawbacks.
Most functions we are interested in are not implemented in e.g. Mathematica. Rather, we have loads of C, C++, and Python code that we are interested in.
But ok, `sympy` alleviates this to some degree. The functions `f1` and `f2` are fairly generic since they use operator overloading. So a symbolic program and a "normal" program could differ by only a few import statements
from sympy import cos
def f1(x):
    return x**2
def f2(x):
    return cos(x)
versus:
from math import cos
def f1(x):
    return x**2
def f2(x):
    return cos(x)
Note the code is almost exactly the same
But not all our functions are so simple!
Expression swell
Let's look at a quadratic map which is applied a few times
def quadmap(x):
    return x**2 + 3*x + 4
def func(x):
    for i in range(6):
        x = quadmap(x)
    return x
quad_6_times = func(symbolic_x)
quad_6_times
This looks pretty intimidating. What happened? Symbolic programs run through the program and accumulate the full program into a single expression.
If we just blindly differentiate this, it looks like this
quad_6_times.diff(symbolic_x)
This looks even worse!
Also note that if we just blindly substitute some value for x, e.g. x=2, we would be computing a lot of the same terms many times. E.g. in the above expression $x^2+3x+4$ appears in a lot of places due to the "structure" of the original program.
If you knew the structure of the program, you likely could precompute some of these repeating terms. However, once it has all been expanded, all this knowledge about the structure is gone!
Modern CAS can recover some of this by finding "common subexpressions" (CSE)
sympy.cse(quad_6_times)
([(x0, x**2), (x1, (3*x + x0 + 4)**2), (x2, (9*x + 3*x0 + x1 + 16)**2), (x3, (27*x + 9*x0 + 3*x1 + x2 + 52)**2), (x4, (81*x + 27*x0 + 9*x1 + 3*x2 + x3 + 160)**2)], [729*x + 243*x0 + 81*x1 + 27*x2 + 9*x3 + 3*x4 + (243*x + 81*x0 + 27*x1 + 9*x2 + 3*x3 + x4 + 484)**2 + 1456])
But it's not as automatic and may not find all relevant subexpressions. In any case it's trying hard to recover some of the structure that is already implicitly present in the program we want to differentiate
Control Flow
In addition to looping constructs like the one above, a lot of the functions we are interested in have control flow structures like if/else statements, while loops, etc.
If we try to create a symbolic expression with conditionals we fail badly
def func(x):
    if x > 2:
        return x**2
    else:
        return x**3
symbolic_result = func(symbolic_x)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-b04e2306166c> in <module>
      5         return x**3
      6
----> 7 symbolic_result = func(symbolic_x)

<ipython-input-12-b04e2306166c> in func(x)
      1 def func(x):
----> 2     if x > 2:
      3         return x**2
      4     else:
      5         return x**3

~/Code/pyhfdev/dev/pyhfdevenv/lib/python3.7/site-packages/sympy/core/relational.py in __nonzero__(self)
    382
    383     def __nonzero__(self):
--> 384         raise TypeError("cannot determine truth value of Relational")
    385
    386     __bool__ = __nonzero__

TypeError: cannot determine truth value of Relational
That's too bad, because this is a perfectly respectable function almost everywhere
xi = np.linspace(-2,5,1001)
yi = np.asarray([func(xx) for xx in xi])
plt.plot(xi,yi)
plt.scatter(xi,yi)
plt.title("pretty smooth except at x=2")
plt.show()
If we could afford finite differences, they would compute gradients just fine.
g = np.gradient(yi,xi)
plt.plot(xi,g)
plt.scatter(xi,g)
plt.ylim(-2,10)
plt.title('''\
parabolesque gradient in x^3 region,
linear in x^2 region as expected''');
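For completeness: sympy can represent this particular function if we rewrite the branch by hand as a `sympy.Piecewise` expression. The sketch below is such a rewrite (the names `piecewise_func` and `piecewise_deriv` are made up here). It works, but only because we translated the `if`/`else` ourselves; the original Python function cannot be reused as-is.

```python
import sympy

x = sympy.symbols('x')
# hand-written symbolic version of the if/else program;
# this is a manual rewrite, not the original code
piecewise_func = sympy.Piecewise((x**2, x > 2), (x**3, True))
piecewise_deriv = piecewise_func.diff(x)

print(piecewise_deriv)
# evaluate the derivative on both sides of the branch point
print(piecewise_deriv.subs(x, 3), piecewise_deriv.subs(x, 1))  # 6 3
```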
In short: symbolic differentiation is not our saving grace.
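To recap:
Finite differences is simple and works for any program we can evaluate, in any language, but it is only approximate and needs many evaluations, especially in high dimensions.
Symbolic differentiation is exact, but it suffers from expression swell and cannot handle programs with control flow.
What we want is a third approach: one that is exact, efficient, and able to handle arbitrary programs.
This third approach is 'Automatic' differentiation.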
Before we start, let's first look at linear transformations from ℝᵐ → ℝⁿ: $$y(x) = Ax$$
With a given basis, this is representable as a (rectangular) matrix: $$y_i(x) = A_{ij}x_j$$
For a given linear problem, there are a few ways we can run this computation
full matrix computation
i.e. we store the full (dense) $nm$ elements of the rectangular matrix and compute an explicit matrix multiplication.
The computation can be fully generic for any matrix
def result(matrix, vector):
    return np.matmul(matrix,vector)
sparse matrix computation
If many $A_{ij}=0$, it might be wasteful to expend memory on them. We can instead create a sparse matrix by storing only the non-zero elements.
The computation can be kept general
def result(sparse_matrix, vector):
    return sparse_matmul(sparse_matrix,vector)
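`sparse_matmul` is not defined anywhere in this tutorial. One possible stand-in, assuming a dict-of-keys storage `{(row, col): value}` chosen purely for illustration, could look like this:

```python
def sparse_matmul(sparse_matrix, vector):
    """Multiply a dict-of-keys sparse matrix {(row, col): value} with a dense vector."""
    # assumption: the number of rows is inferred from the largest stored row index
    n_rows = 1 + max(i for i, _ in sparse_matrix)
    out = [0.0] * n_rows
    for (i, j), a_ij in sparse_matrix.items():   # only non-zero entries are visited
        out[i] += a_ij * vector[j]
    return out

# only the three non-zero entries of [[2, 0, 0], [0, 1, 3]] are stored
A_sparse = {(0, 0): 2.0, (1, 1): 1.0, (1, 2): 3.0}
y = sparse_matmul(A_sparse, [1.0, 2.0, 3.0])
print(y)  # [2.0, 11.0]
```

The function stays generic: a different matrix only means a different dictionary, not different code.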
matrix-free computation
In many cases a linear program is not explicitly given by a matrix, but as code / a "black-box" function. As long as the computation in the body of the function keeps to (hard-coded) linear transformations, the program is linear. The matrix elements are no longer explicitly enumerated and stored in a data structure but implicitly defined in the source code.
This is no longer a generic computation: each linear transformation is its own program. At the same time this is also the most memory-efficient representation. No lookup table is needed since all constants are hard-coded.
def linear_program(vector):
    # schematic: A_11, A_12, A_22 stand for constants hard-coded in the source
    x1,x2 = vector
    z1,z2 = 0,0
    z1 += A_11*x1
    z1 += A_12*x2
    z2 += A_22*x2
    return [z1,z2]
def matrix_vector_product(x):
    x1,x2,x3 = x
    z1,z2 = 0,0
    z1 += 2*x1 #MVP statement 1
    z2 += 1*x2 #MVP statement 2
    z2 += 3*x3 #MVP statement 3
    return np.asarray([z1,z2])
M = np.concatenate([
    matrix_vector_product(np.asarray([1,0,0])).reshape(-1,1),
    matrix_vector_product(np.asarray([0,1,0])).reshape(-1,1),
    matrix_vector_product(np.asarray([0,0,1])).reshape(-1,1),
],axis=1)
print(f'M derived from matrix-vector products:\n{M}')
M derived from matrix-vector products:
[[2 0 0]
 [0 1 3]]
The same matrix induces a "dual" linear map: ℝⁿ → ℝᵐ $$ x_k = y_i A_{ik}$$
i.e. instead of a Matrix-vector product it's now a vector-Matrix product (VMP)
If one has access to a "vector-Matrix" program corresponding to a matrix $A$ one can again -- as in the MVP-case -- recover the matrix elements, by feeding in basis vectors.
This time the matrix is built one row at a time
def vector_matrix_product(z):
    x1,x2,x3 = 0,0,0
    z1,z2 = z
    x3 += z2*3 #VMP version of statement 3
    x2 += z2*1 #VMP version of statement 2
    x1 += z1*2 #VMP version of statement 1
    return np.asarray([x1,x2,x3])
M = np.concatenate([
    vector_matrix_product(np.asarray([1,0])).reshape(1,-1),
    vector_matrix_product(np.asarray([0,1])).reshape(1,-1),
],axis=0)
print(f'M derived from vector-matrix products:\n{M}')
M derived from vector-matrix products:
[[2 0 0]
 [0 1 3]]
For a given linear transformation, characterized by a matrix $A_{ij}$, we have a forward (matrix-vector) and backward (vector-matrix) map $$y_i = A_{ij}x_j$$ $$x_j = y_i A_{ij}$$
and we can use either to recover the full matrix $A_{ij}$
If you look at the code above, you'll notice that the number of calls necessary to the MVP or VMP program is related to the dimensions of the matrix itself.
For an $n\times m$ matrix (for a map: ℝᵐ → ℝⁿ), you need $m$ calls to the "Matrix-vector" program to build the full matrix one-column-at-a-time. Likewise you need $n$ calls to the "vector-Matrix" program to build the matrix one-row-at-a-time.
This becomes relevant for very asymmetric maps: e.g. for scalar maps from very high-dimensional spaces $\mathbb{R}^{10000} \to \mathbb{R}$ the "vector-Matrix" approach is vastly more efficient than the "Matrix-vector" one. There's only one row, so only one call to the VMP program is needed to construct the full matrix!
Similarly, for functions mapping few variables into very high dimensional spaces $\mathbb{R} \to \mathbb{R}^{10000}$ it's the opposite: the "Matrix-vector" approach is much better suited than the "vector-Matrix" one (this time it's a single column!).
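The asymmetry can be checked with a small experiment. The sketch below uses a made-up map from ℝ¹⁰⁰⁰ to ℝ (the names `WEIGHTS`, `mvp` and `vmp` are hypothetical) and recovers its single-row matrix both ways.

```python
N = 1000
WEIGHTS = [float(k + 1) for k in range(N)]   # hypothetical 1 x N matrix A

def mvp(x):
    """Matrix-vector program R^N -> R: y = A x."""
    return [sum(w * xk for w, xk in zip(WEIGHTS, x))]

def vmp(y):
    """Vector-matrix program R -> R^N: x = y A."""
    return [y[0] * w for w in WEIGHTS]

# column-at-a-time: one MVP call per basis vector of R^N
columns = []
for k in range(N):
    e_k = [0.0] * N
    e_k[k] = 1.0
    columns.append(mvp(e_k)[0])

# row-at-a-time: a single VMP call with the only basis vector of R^1
row = vmp([1.0])

print(len(columns), "MVP calls vs 1 VMP call")
print(columns == row)  # True: both recover the same 1 x N matrix
```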
Of course compositions $(f\circ g)(x) = f(g(x))$ of linear maps are also linear, so the above applies.
Depending on whether the "Matrix-vector" or "vector-Matrix" approach is used, the data is propagated forwards or backwards.
The "vector-Matrix" or "Matrix-vector" picture can be generalized to arbitrary directed acyclic graphs.
(one could in principle always recover a rectangular/matrix-like version of a DAG by inserting trivial nodes)
def graph_like(x):
    x1,x2,x3 = x
    y1 = 2*x1+x2
    z1,z2 = y1+2*x3,x3-y1 #note that we reach "over" the "ys" and directly touch x_n
    return np.asarray([z1,z2])
def matrix_like(x):
    x1,x2,x3 = x
    y1 = 2*x1+x2
    y2 = x3 #can just introduce a dummy variable to make it matrix-like
    z1,z2 = y1+2*x3,y2-y1
    return np.asarray([z1,z2])
M = np.concatenate([
    matrix_like(np.asarray([1,0,0])).reshape(-1,1),
    matrix_like(np.asarray([0,1,0])).reshape(-1,1),
    matrix_like(np.asarray([0,0,1])).reshape(-1,1),
],axis=1)
print(f'M derived from matrix like computation:\n{M}')
M = np.concatenate([
    graph_like(np.asarray([1,0,0])).reshape(-1,1),
    graph_like(np.asarray([0,1,0])).reshape(-1,1),
    graph_like(np.asarray([0,0,1])).reshape(-1,1),
],axis=1)
print(f'M derived from graph-like products:\n{M}')
M derived from matrix like computation:
[[ 2  1  2]
 [-2 -1  1]]
M derived from graph-like products:
[[ 2  1  2]
 [-2 -1  1]]
Why are we talking about linear transformations? After all, a lot of the code we write is non-linear! However, derivatives are always linear.
And the derivative (the Jacobian) of a composition $f\circ g$ is the composition of the linear derivatives (the Jacobians of each map), i.e. the full Jacobian matrix is the result of multiplying all Jacobians of the composition. $$J = J_0 J_1 J_2 J_3 \dots J_n $$
(This is just the chain rule) $$y = g(x),\; z = f(y) = f(g(x))\hspace{1cm} \frac{\partial z_i}{\partial x_k} = \frac{\partial z_i}{\partial y_j}\frac{\partial y_j}{\partial x_k}$$
I.e. finding derivatives means characterizing the Jacobian matrix. From the above discussion, we can use the "Jacobian-vector product" (JVP, builds Jacobians column-wise) or "vector-Jacobian product" (VJP, builds Jacobians row-wise) approach.
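A quick numerical check of this picture, using two random matrices as stand-ins for the Jacobians of an outer map `f` and an inner map `g` (the shapes and names are arbitrary choices for this sketch):

```python
import numpy as np

rng = np.random.default_rng(0)
J_f = rng.normal(size=(1, 50))    # Jacobian of the outer map f: R^50 -> R^1
J_g = rng.normal(size=(50, 200))  # Jacobian of the inner map g: R^200 -> R^50

J = J_f @ J_g                     # full Jacobian of the composition, shape (1, 200)

# JVP: push an input direction v forwards, right-to-left, never forming J
v = np.zeros(200)
v[3] = 1.0
jvp = J_f @ (J_g @ v)             # equals column 3 of J

# VJP: pull an output direction u backwards, left-to-right, never forming J
u = np.array([1.0])
vjp = (u @ J_f) @ J_g             # equals row 0 of J

print(np.allclose(jvp, J[:, 3]), np.allclose(vjp, J[0]))  # True True
```

Both products only ever multiply a vector with a matrix, which is why neither needs the full Jacobian of the composition in memory.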
In the language of automatic differentiation, the JVP approach is called forward mode and the VJP approach is called reverse (or backward) mode.
Let's work this out on a very simple problem
In the forward pass we use "Matrix-vector" products and need to do two evaluations
In the backward pass we use "vector-Matrix" products and need to do only a single evaluation
Both approaches give the same result. Since this is a map from $\mathbb{R}^2 \to \mathbb{R}^1$ the backward pass is more efficient than the forward pass
Let's look at a real-life example
$$z(x_1,x_2) = y + x_2 = x_1x_2 + x_2$$
This is easy Python code
def mul_func(x1,x2):
    return x1*x2
def sum_func(x1,x2):
    return x1+x2
def function(x):
    x1,x2 = x
    y = mul_func(x1,x2)
    z = sum_func(y,x2)
    return z
print(function([2,4]))
12
In the forward pass, an autodiff system needs to create a JVP implementation for each elementary operation
def mul_jvp(x1,dx1,x2,dx2):
    y = mul_func(x1,x2)
    dy = x1*dx2 + x2*dx1
    return y, dy
def sum_jvp(x1,dx1,x2,dx2):
    return sum_func(x1,x2), dx1 + dx2
def function_jvp(x,dx):
    x1,x2 = x
    dx1,dx2 = dx
    y, dy = mul_jvp(x1,dx1,x2,dx2)
    z, dz = sum_jvp(y,dy, x2, dx2)
    return z,dz
Since in the forward pass we build "column-at-a-time" and our final Jacobian has shape (1x2), i.e. two columns, we need two forward passes to get the full Jacobian. Note that for each forward pass we also get the fully computed function value delivered on top!
Also note that the "JVP" version of the function has the same structure as the original function. For each call in the original program there is an equivalent call in the JVP program. However, the JVP call always does two things at once: it computes the value and propagates the derivative.
So it has roughly 2x the run-time of the original program (depending on the complexity of the derivatives). Said another way: computing one pass of the derivative has the same computational complexity as the function itself.
print(function_jvp([2,4],[1,0]))
print(function_jvp([2,4],[0,1]))
(12, 4)
(12, 3)
For the backward pass we build "row-at-a-time". For each elementary operation we need to build a VJP implementation
def mul_vjp(x1,x2,dx1,dx2,dout):
    dx2 += dout * x1
    dx1 += dout * x2
    return dx1,dx2
def sum_vjp(x1,x2,dx1,dx2,dout):
    dx1 += dout * 1
    dx2 += dout * 1
    return dx1,dx2
def function_vjp(x,dz):
    #run forward
    x1,x2 = x
    y = mul_func(x1,x2)
    z = sum_func(y,x2)
    #zero gradients
    dy = 0
    dx1 = 0
    dx2 = 0
    #run backward
    #z = sum_func(y,x2), so its VJP updates dy and dx2
    dy,dx2 = sum_vjp(y,x2, dy, dx2, dz)
    dx1,dx2 = mul_vjp(x1,x2, dx1, dx2, dy)
    return z,[dx1,dx2]
Here, we see the power of backward propagation (or the reverse mode): we get all gradients of the single row in one go. Since this Jacobian only has one row, we're done! And we get the function value delivered on top of the gradients as well!
print(function_vjp([2,4],1.0))
(12, [4.0, 3.0])
Again, let's look at the "VJP" code. The forward pass is exactly the same as the original function. This just records the final result and all intermediate values, which we will need for the backward pass.
Moving on to the backward pass, we see again, as in JVP, it has the same structure as the forward pass. For each call to a subroutine there is an equivalent call in the backward pass to compute the VJP.
As in the JVP case, the computational complexity of one backward pass is roughly the same as the forward pass. Now, unlike the JVP case, we only needed a single pass for all the gradients of this scalar function. So obtaining the full gradient of a scalar function costs only a small constant multiple of evaluating the function itself.
Above we have built a manual autodiff system. Let's recap what we needed to do
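Once given a program, we had to do the following.
In the forward mode: implement a JVP for each elementary operation, then write a JVP version of the program that mirrors the original call by call and propagates each value together with its derivative.
In the backward mode: implement a VJP for each elementary operation, run the program forward while recording the intermediate values, then walk through the calls in reverse order and accumulate the gradients.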
All of this is pretty mechanistic and hence "automatable". And given that it's a very narrow domain of only implementing JVP/VJP operations, this is easy to do in any language.
That's why it's automatic differentiation
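To illustrate how mechanical this is, here is a toy reverse-mode system in plain Python that records each operation while the program runs and then replays the record backwards. The class `Var` and the function `backward` are hypothetical names for this sketch; real libraries are organized differently and support far more operations.

```python
class Var:
    """Scalar that records how it was computed so gradients can flow backwards."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents   # tuple of (parent Var, local partial derivative)
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

def backward(output):
    """Accumulate d(output)/d(var) into .grad of every Var that output depends on."""
    # topological order: every node appears after all of its parents
    order, seen = [], set()
    def visit(node):
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)
    visit(output)
    output.grad = 1.0
    for node in reversed(order):               # process outputs before inputs
        for parent, local in node.parents:
            parent.grad += node.grad * local   # one VJP per recorded operation

x1, x2 = Var(2.0), Var(4.0)
z = x1 * x2 + x2          # same toy function as before, written with operators
backward(z)
print(z.value, x1.grad, x2.grad)  # 12.0 4.0 3.0
```

The user-facing program is a single ordinary expression; all the bookkeeping lives in the overloaded operators.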
What we gain from this is that we get
Some notes on pros and cons:
In the forward mode:
the signature of each operation basically extends
float f(float x,float y,float z)
to
pair<float> f(float x,float dx,float y,float dy,float z,float dz)
if you use composite types ("dual numbers") that hold both x,dx you can basically keep the signature unchanged
f(dual x, dual y, dual z)
together with operator overloading on these dual types, e.g. dual * dual, you can essentially keep the source code unchanged
float f(float x, float y): return x*y
->
dual f(dual x,dual y): return x*y
That means it's very easy to implement. And memory efficient: no superfluous values are kept when they run out of scope.
But forward mode is better suited for vector-valued functions of few parameters
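The dual-number idea fits in a few lines of Python. This is a minimal sketch supporting only `+` and `*`; the class name `Dual` and its `lift` helper are assumptions for illustration.

```python
class Dual:
    """Pair (value, derivative) with overloaded + and *; a minimal sketch."""
    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    @staticmethod
    def lift(other):
        # plain numbers are treated as constants with zero derivative
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.value + other.value, self.deriv + other.deriv)

    def __mul__(self, other):
        other = Dual.lift(other)
        # product rule
        return Dual(self.value * other.value,
                    self.value * other.deriv + self.deriv * other.value)

    __radd__ = __add__
    __rmul__ = __mul__

def function(x1, x2):
    # unchanged source code: z = x1*x2 + x2
    return x1 * x2 + x2

# seed dx1=1, dx2=0 to get dz/dx1, then dx1=0, dx2=1 to get dz/dx2
z1 = function(Dual(2.0, 1.0), Dual(4.0, 0.0))
z2 = function(Dual(2.0, 0.0), Dual(4.0, 1.0))
print(z1.value, z1.deriv)  # 12.0 4.0
print(z2.value, z2.deriv)  # 12.0 3.0
```

Two calls are needed because the function has two inputs, which matches the column-at-a-time picture.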
In the reverse mode:
Autodiff libraries exist, and a lot of them, in many languages. On the other hand, try finding CAS systems in each of those.
This is PyHEP, so let's focus on Python. Here, basically what you think of as "Machine Learning frameworks" are at the core autodiff libraries
Let's focus on jax
import jax
import jax.numpy as jnp
def f(x):
    return x**2
`jax.numpy` is almost a drop-in replacement for `numpy`. I do `import jax.numpy as jnp`, but if you're daring you could do `import jax.numpy as np`
x = jnp.array([1,2,3])
y = jnp.array([2,3,4])
print(x+y)
print(x*y)
print(jnp.log(x))
print(jnp.exp(y))
[3 5 7]
[ 2  6 12]
[0.         0.69314718 1.09861229]
[ 7.3890561  20.08553692 54.59815003]
def f(x):
    return x**3
print(f(4.0))
print(jax.grad(f)(4.0)) #boom!
print(jax.grad(jax.grad(f))(4.0)) #boom!
print(jax.grad(jax.grad(jax.grad(f)))(4.0)) #boom!
print(jax.grad(jax.grad(jax.grad(jax.grad(f))))(4.0)) #boom!
64.0
48.0
24.0
6.0
0.0
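The manual JVP and VJP passes from the toy example $z = x_1x_2 + x_2$ have direct counterparts in JAX: `jax.jvp` and `jax.vjp`. The sketch below re-defines the toy function with array indexing so that it is self-contained; the variable names are arbitrary.

```python
import jax
import jax.numpy as jnp

def function(x):
    # same toy function as in the manual example: z = x1*x2 + x2
    return x[0] * x[1] + x[1]

x = jnp.array([2.0, 4.0])

# forward mode: one pass per input direction (one Jacobian column each)
z, dz_dx1 = jax.jvp(function, (x,), (jnp.array([1.0, 0.0]),))
_, dz_dx2 = jax.jvp(function, (x,), (jnp.array([0.0, 1.0]),))

# reverse mode: one backward pass yields the whole row
z_again, pullback = jax.vjp(function, x)
(row,) = pullback(jnp.ones_like(z_again))

print(z, dz_dx1, dz_dx2)  # 12.0 4.0 3.0
print(row)                # [4. 3.]
```

`jax.grad` is a convenience wrapper around the reverse-mode path for scalar-valued functions.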
xi = jnp.linspace(-5,5)
yi = f(xi)
plt.plot(xi,yi)
jax.grad(f)(xi)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-31-9fa07a940ab8> in <module>
----> 1 jax.grad(f)(xi)

~/Code/pyhfdev/dev/pyhfdevenv/lib/python3.7/site-packages/jax/api.py in grad_f(*args, **kwargs)
    411   @wraps(fun, docstr=docstr, argnums=argnums)
    412   def grad_f(*args, **kwargs):
--> 413     _, g = value_and_grad_f(*args, **kwargs)
    414     return g
    415

~/Code/pyhfdev/dev/pyhfdevenv/lib/python3.7/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
    472     else:
    473       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
--> 474     _check_scalar(ans)
    475     dtype = dtypes.result_type(ans)
    476     tree_map(partial(_check_output_dtype_grad, holomorphic), ans)

~/Code/pyhfdev/dev/pyhfdevenv/lib/python3.7/site-packages/jax/api.py in _check_scalar(x)
    493     if isinstance(aval, ShapedArray):
    494       if aval.shape != ():
--> 495         raise TypeError(msg("had shape: {}".format(aval.shape)))
    496       else:
    497         raise TypeError(msg("had abstract value {}".format(aval)))

TypeError: Gradient only defined for scalar-output functions. Output had shape: (50,).
Whoops, `jax.grad` defaults to reverse mode with a single backward pass, but through broadcasting we get a vector -> vector map. We can use some jax magic to "unbroadcast" the function, take the gradient and re-broadcast it
jax.vmap(jax.grad(f))(xi)
DeviceArray([7.50000000e+01, 6.90024990e+01, 6.32548938e+01, 5.77571845e+01, 5.25093711e+01, 4.75114536e+01, 4.27634319e+01, 3.82653061e+01, 3.40170762e+01, 3.00187422e+01, 2.62703040e+01, 2.27717618e+01, 1.95231154e+01, 1.65243648e+01, 1.37755102e+01, 1.12765514e+01, 9.02748855e+00, 7.02832153e+00, 5.27905040e+00, 3.77967514e+00, 2.53019575e+00, 1.53061224e+00, 7.80924615e-01, 2.81132861e-01, 3.12369846e-02, 3.12369846e-02, 2.81132861e-01, 7.80924615e-01, 1.53061224e+00, 2.53019575e+00, 3.77967514e+00, 5.27905040e+00, 7.02832153e+00, 9.02748855e+00, 1.12765514e+01, 1.37755102e+01, 1.65243648e+01, 1.95231154e+01, 2.27717618e+01, 2.62703040e+01, 3.00187422e+01, 3.40170762e+01, 3.82653061e+01, 4.27634319e+01, 4.75114536e+01, 5.25093711e+01, 5.77571845e+01, 6.32548938e+01, 6.90024990e+01, 7.50000000e+01], dtype=float64)
that looks better!
`jax.grad(f)` just returns another function. Of course we can just take the gradient of that as well. And so on...
g1i = jax.vmap(jax.grad(f))(xi)
g2i = jax.vmap(jax.grad(jax.grad(f)))(xi)
g3i = jax.vmap(jax.grad(jax.grad(jax.grad(f))))(xi)
plt.plot(xi,yi, label = "f")
plt.plot(xi,g1i, label = "f'")
plt.plot(xi,g2i, label = "f''")
plt.plot(xi,g3i, label = "f'''")
plt.legend()
Back when discussing symbolic differentiation, we hit a snag when adding control flow to our program. In JAX this just passes through transparently.
Let's compare this to finite differences, so far the only system we had to compute derivatives of control-flow-ful programs
def control_flow_func(x):
    if x > 2:
        return x**2
    else:
        return x**3
first_gradient_of_cflow = jax.grad(control_flow_func)
xi = jnp.linspace(-2,5,101)
yi = np.asarray([first_gradient_of_cflow(xx) for xx in xi])
plt.plot(xi,yi,c = 'k')
xi = jnp.linspace(-2,5,11)
yi = np.asarray([first_gradient_of_cflow(xx) for xx in xi])
plt.scatter(xi,yi, label = 'jax autodiff')
xi = jnp.linspace(-2,5,11)
yi = np.asarray([control_flow_func(xx) for xx in xi])
plt.scatter(xi,np.gradient(yi,xi), label = 'finite differences')
plt.legend()
We can start to see the benefits of autodiff. Among other things, finite differences becomes quite sensitive to exactly where the evaluation points are (e.g. with respect to the discontinuity)
As we compute higher derivatives, this error compounds badly for finite differences. But for autodiff, it's smooth sailing!
second_gradient_of_cflow = jax.grad(first_gradient_of_cflow)
xi = jnp.linspace(-2,5,101)
yi = np.asarray([second_gradient_of_cflow(xx) for xx in xi])
plt.plot(xi,yi,c = 'k')
xi = jnp.linspace(-2,5,11)
yi = np.asarray([second_gradient_of_cflow(xx) for xx in xi])
plt.scatter(xi,yi, label = '2nd deriv jax autodiff')
xi = jnp.linspace(-2,5,11)
yi = np.asarray([control_flow_func(xx) for xx in xi])
plt.scatter(xi,np.gradient(np.gradient(yi,xi),xi), label = '2nd deriv finite differences',) # pass xi to both calls so the grid spacing is used twice
plt.legend()
Not all our programs are so simple. Consider this
def func(x):
    # pseudocode: solve stands for a generic iterative root finder
    y_root = solve(x**2 + y**2 == 1, x = x, y_start = 2.0)
    return y_root
Solving this often goes through some iterative algorithm like Brent bracketing. But differentiating through the iteration is not the right solution.
We can add our own custom gradients
Recall the implicit function theorem $$ f(x,y) = x^2 + y^2 -1 = 0 $$
$$ df = 0 \leftrightarrow 2x dx + 2y dy = 0 $$
$$ dy/dx = -x/y $$
How do we teach this to an autodiff system?
Recall:
import jax
from jax import core
import numpy as np
from jax.interpreters import ad
import scipy.optimize
import functools
import matplotlib.pyplot as plt
def findroot(f):
    return scipy.optimize.brentq(f,a = 0,b = 10)
def func(x,y):
    return x**2 + y**2 - 1
def y_for_x(x):
    return findroot(functools.partial(func,x))
xi = np.linspace(-1,1)
yi = np.asarray([y_for_x(xx) for xx in xi])
plt.plot(xi,yi)
findrootjax_p = core.Primitive('findrootjax')
findrootjax_p.def_impl(lambda x: y_for_x(x))
# the VJP scales the incoming cotangent g by dy/dx = -x/y
ad.defvjp(findrootjax_p, lambda g, x: g * (- x / y_for_x(x)))
def findrootjax(x):
    return findrootjax_p.bind(x)
jax.value_and_grad(findrootjax)(0.5)
xi = np.linspace(-1,1,101)
yi = [findrootjax(v) for v in xi]
plt.plot(xi,yi)
xi = np.linspace(-1,1,21)
vg = np.asarray([np.asarray(jax.value_and_grad(findrootjax)(v)) for v in xi])
plt.scatter(xi,vg[:,0])
plt.quiver(xi,vg[:,0],np.ones_like(vg[:,0]),vg[:,1],
angles = 'uv',
alpha = 0.5,
)
plt.gcf().set_size_inches(5,2.5)
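The primitive-level calls used here (`core.Primitive`, `ad.defvjp`) belong to JAX internals and may differ or be unavailable depending on the JAX version; treat that as an assumption about version drift, not something this tutorial states. A sketch of the same idea with the public `jax.custom_jvp` decorator follows. Note the swapped-in technique: a hand-rolled Newton iteration written in `jax.numpy` replaces `brentq` so that the example is fully traceable, and the name `y_of_x` is made up.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def y_of_x(x):
    # solve x**2 + y**2 - 1 = 0 for y > 0 with a fixed number of Newton steps
    y = jnp.ones_like(x)
    for _ in range(20):
        y = y - (x**2 + y**2 - 1.0) / (2.0 * y)
    return y

@y_of_x.defjvp
def y_of_x_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    y = y_of_x(x)
    # implicit function theorem: dy/dx = -x/y, no differentiation through the loop
    return y, (-x / y) * dx

value, grad = jax.value_and_grad(y_of_x)(0.5)
print(value, grad)  # roughly 0.866 and -0.577
```

With the custom rule in place, the cost of the derivative no longer depends on how many iterations the solver needs.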
Of course we can use automatic differentiation for neural networks. But other things in HEP can also make use of gradients. A prime example is statistical analysis.
For a maximum likelihood fit we want to minimize the negative log likelihood
$\theta^* = \mathrm{argmin}_\theta(-\log L)$
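As a toy illustration of such a fit (deliberately much simpler than the pyhf model used in this section), here is a gradient-descent minimization of the negative log-likelihood of a single Poisson count. The observed count, starting value, step size and number of steps are all arbitrary choices for this sketch.

```python
import jax
import jax.numpy as jnp

def nll(mu, n_observed=15.0):
    # negative log-likelihood of one Poisson count, dropping the constant log(n!) term
    return mu - n_observed * jnp.log(mu)

grad_nll = jax.grad(nll)   # derivative with respect to mu, via autodiff

mu = 10.0                  # hypothetical starting value
for _ in range(100):
    mu = mu - 5.0 * grad_nll(mu)   # plain gradient descent, step size chosen by hand

print(float(mu))  # close to 15, the observed count
```

Real fits use more robust minimizers, but they consume the same gradient information.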
import jax
import jax.numpy as jnp
import numpy as np
import pyhf
import matplotlib.pyplot as plt
pyhf.set_backend('jax')
m = pyhf.simplemodels.hepdata_like([5.],[10.],[3.5])
pars = jnp.array(m.config.suggested_init())
data = jnp.array([15.] + m.config.auxdata)
m.logpdf(pars,data)
DeviceArray([-4.25748227], dtype=float64)
bestfit = pyhf.infer.mle.fit(data,m)
bestfit
DeviceArray([1., 1.], dtype=float64)
grid = x,y = np.mgrid[.5:1.5:101j,.5:1.5:101j]
points = np.swapaxes(grid,0,-1).reshape(-1,2)
v = jax.vmap(m.logpdf, in_axes = (0,None))(points,data)
v = np.swapaxes(v.reshape(101,101),0,-1)
plt.contourf(x,y,v, levels = 100)
plt.contour(x,y,v, levels = 20, colors = 'w')
grid = x,y = np.mgrid[.5:1.5:11j,.5:1.5:11j]
points = np.swapaxes(grid,0,-1).reshape(-1,2)
values, gradients = jax.vmap(
    jax.value_and_grad(
        lambda p,d: m.logpdf(p,d)[0]
    ), in_axes = (0,None)
)(points,data)
plt.quiver(
    points[:,0],
    points[:,1],
    gradients[:,0],
    gradients[:,1],
    angles = 'xy',
    scale = 75
)
plt.scatter(bestfit[0],bestfit[1], c = 'r')
plt.xlim(0.5,1.5)
plt.ylim(0.5,1.5)
plt.gcf().set_size_inches(5,5)