! pip install -U -q jax jaxlib
import jax.numpy as jnp
from jax.experimental.ode import odeint
def f(u, t, sigma, rho, beta):
    """Evaluate the Lorenz system right-hand side du/dt = f(u, t).

    Args:
        u: State, unpacked as ``(x, y, z)``.
        t: Time. Not referenced (the equations are autonomous), but it is
            required by ``odeint``'s ``func(y, t, *args)`` calling convention.
        sigma: Coupling coefficient in the dx/dt equation.
        rho: Coefficient in the dy/dt equation.
        beta: Damping coefficient in the dz/dt equation.

    Returns:
        A ``jnp`` array ``[dx/dt, dy/dt, dz/dt]``.
    """
    x, y, z = u
    dx_dt = sigma * (y - x)
    dy_dt = x * (rho - z) - y
    dz_dt = x * y - beta * z
    return jnp.array([dx_dt, dy_dt, dz_dt])
u0 = jnp.array([1.0, 0.0, 0.0])
tspan = (0., 100.)
t = jnp.linspace(0, 100, 1001)
sol = odeint(f, u0, t, 10.0, 28.0, 8/3, rtol=1e-8, atol=1e-8)
%timeit odeint(f, u0, t, 10.0, 28.0, 8/3, rtol=1e-8, atol=1e-8).block_until_ready()
100 loops, best of 3: 3.66 ms per loop