import jax
import jax.numpy as jnp
import numpy as np
def f(x):
    return x**3 + x**2 + x
f(1.0)
fp = jax.grad(f)  # f'(x) = 3x^2 + 2x + 1
fp(1.0)
fpp = jax.grad(fp)  # f''(x) = 6x + 2
fpp(1.0)
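As a quick sanity check, here is a sketch that compares `jax.grad` against the derivatives worked out by hand. `fp_exact` and `fpp_exact` are illustrative names, not part of JAX. Because `jax.grad` needs a scalar-valued function, `jax.vmap` is used to evaluate the scalar derivative at several points in one call:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x**3 + x**2 + x

# Hand-derived derivatives of f, used only for comparison
def fp_exact(x):
    return 3 * x**2 + 2 * x + 1

def fpp_exact(x):
    return 6 * x + 2

fp = jax.grad(f)
fpp = jax.grad(fp)

# jax.grad requires a scalar output, so map the scalar derivative over xs
xs = jnp.linspace(-2.0, 2.0, 5)
fp_vals = jax.vmap(fp)(xs)
fpp_vals = jax.vmap(fpp)(xs)

print(jnp.allclose(fp_vals, fp_exact(xs)))    # True
print(jnp.allclose(fpp_vals, fpp_exact(xs)))  # True
```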
Let's watch the tracer:
def f(x):
    # Under jax.jit, these prints run only while the function is being traced
    print(f"{x = }")
    y = x**2
    print(f"{y = }")
    return y
f_jit = jax.jit(f)
f_jit(2)
f_jit(2)
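Notice that the Python code runs only once, and that what gets passed in is not an integer at all but a tracer: an abstract placeholder that knows only the shape and dtype of the input and records the operations applied to it. From then on, calling the function no longer runs the Python code; JAX reuses the compiled version. Well, as long as you use the same input types and shapes: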
f_jit(1.0)
f_jit(1.0)
f_jit(1)
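One way to see exactly when retracing happens is to count traces with a Python side effect. This is a sketch; `g` and `trace_count` are hypothetical names, and the count assumes JAX's default behavior of caching one trace per distinct combination of input shape and dtype (Python scalars are weakly typed, so `2` and `3` share a trace while `2.0` does not):

```python
import jax
import jax.numpy as jnp

trace_count = 0

def g(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces g
    return x**2

g_jit = jax.jit(g)

g_jit(2)                # int scalar: traced (count 1)
g_jit(3)                # same type and shape: cached, no retrace
g_jit(2.0)              # float scalar: retraced (count 2)
g_jit(jnp.ones(3))      # new shape (3,): retraced (count 3)
g_jit(jnp.ones(3) * 5)  # same shape and dtype as before: cached
print(trace_count)      # 3
```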
Inside a jitted function, you can't use Python control flow that depends on the value of a tracer, and you can't create arrays whose shapes depend on traced values:
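@jax.jit
def broken(x):
    # Python `if` needs a concrete bool, but under jit x is an abstract tracer
    if x == 3:
        return x**3
    return x
broken(2)  # raises a concretization error (TracerBoolConversionError in recent JAX versions)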
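A few common workarounds, sketched under the assumption that the data-dependent branch is the only problem (the function names are illustrative). `jnp.where` computes both branches and selects the result, `jax.lax.cond` picks a branch at runtime, and `static_argnums` turns an argument into a compile-time constant so JAX retraces for every new value. The static route is also how you handle an argument that determines an output shape:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Option 1: compute both branches and select; works for arrays elementwise too
@jax.jit
def fixed_where(x):
    return jnp.where(x == 3, x**3, x)

# Option 2: lax.cond traces both branches and chooses one at runtime
@jax.jit
def fixed_cond(x):
    return jax.lax.cond(x == 3, lambda v: v**3, lambda v: v, x)

# Option 3: treat x as static, so plain Python `if` is allowed;
# JAX retraces and recompiles for each distinct value of x
@partial(jax.jit, static_argnums=0)
def fixed_static(x):
    if x == 3:
        return x**3
    return x

# Output shapes must be known at trace time, so n has to be static here
@partial(jax.jit, static_argnums=0)
def first_n_squares(n):
    return jnp.arange(n) ** 2

print(fixed_where(3), fixed_cond(3), fixed_static(3))
print(first_n_squares(4))
```

Static arguments are convenient but trigger a recompile for every new value, so they suit arguments that take only a few distinct values.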
Unlike NumPy arrays, JAX arrays are immutable. You should also write pure functions (functions without side effects or hidden state).
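Purity matters under `jax.jit` because a global read inside a jitted function is captured as a constant when the function is traced; later changes to that global are silently ignored while the cached trace is reused. A minimal sketch (the names `scale`, `impure`, and `pure` are illustrative):

```python
import jax

scale = 2.0

@jax.jit
def impure(x):
    # Reads a global: its value is baked in as a constant when impure is traced
    return x * scale

first = impure(1.0)   # traced here, with scale == 2.0
scale = 10.0
second = impure(1.0)  # same input type and shape: cached trace reused, new scale ignored

@jax.jit
def pure(x, s):
    # Everything the result depends on is an explicit argument
    return x * s

third = pure(1.0, scale)
print(first, second, third)  # 2.0 2.0 10.0
```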
Because JAX arrays are immutable, you can't, for example, do an in-place set:
jarr = jnp.zeros((3, 3))
# Raises TypeError: JAX arrays don't support in-place item assignment
jarr[np.diag(np.ones(3, dtype=bool))] = 1
jarr
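For contrast, here is the same masked assignment in NumPy and JAX side by side. In this sketch NumPy mutates the array in place, while JAX raises a `TypeError`, which is caught here so the script keeps running:

```python
import numpy as np
import jax.numpy as jnp

mask = np.diag(np.ones(3, dtype=bool))  # 3x3 boolean mask, True on the diagonal

narr = np.zeros((3, 3))
narr[mask] = 1  # NumPy: modifies narr in place

jarr = jnp.zeros((3, 3))
raised = False
try:
    jarr[mask] = 1  # JAX: arrays are immutable, so this raises
except TypeError as err:
    raised = True
    print(f"TypeError: {err}")
```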
JAX provides the .at property to express this update without in-place mutation; it returns a new array and leaves the original unchanged:
j1 = jnp.zeros((3, 3))
# .at[...].set returns a new array; j1 itself stays all zeros
j2 = j1.at[np.diag(np.ones(3, dtype=bool))].set(1)
j2
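`.at` supports more than `set`; `add` and `multiply` follow the same pattern of returning a new array. This sketch uses integer index arrays instead of a boolean mask, which keeps the size of the update known at trace time, so the same pattern also works inside `jax.jit` (the helper `set_diag` is a hypothetical name):

```python
import jax
import jax.numpy as jnp

j1 = jnp.zeros((3, 3))
idx = jnp.arange(3)

j2 = j1.at[idx, idx].set(1.0)   # ones on the diagonal
j3 = j2.at[0, :].add(5.0)       # add 5 to every entry of row 0
j4 = j3.at[:, 2].multiply(2.0)  # double every entry of column 2
print(j4)
print(j1.sum())  # 0.0: the original is never modified

@jax.jit
def set_diag(x, v):
    # x.shape is static under jit, so jnp.arange(x.shape[0]) is allowed here
    i = jnp.arange(x.shape[0])
    return x.at[i, i].set(v)

print(set_diag(jnp.zeros((2, 2)), 7.0))
```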