! pip install -U git+https://github.com/shoyer/jax.git@gmres-cleanup
Collecting git+https://github.com/shoyer/jax.git@gmres-cleanup Cloning https://github.com/shoyer/jax.git (to revision gmres-cleanup) to /tmp/pip-req-build-3ym4dw3w Running command git clone -q https://github.com/shoyer/jax.git /tmp/pip-req-build-3ym4dw3w Running command git checkout -b gmres-cleanup --track origin/gmres-cleanup Switched to a new branch 'gmres-cleanup' Branch 'gmres-cleanup' set up to track remote branch 'gmres-cleanup' from 'origin'. Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax==0.2.6) (1.18.5) Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jax==0.2.6) (0.10.0) Requirement already satisfied, skipping upgrade: opt_einsum in /usr/local/lib/python3.6/dist-packages (from jax==0.2.6) (3.3.0) Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax==0.2.6) (1.15.0) Building wheels for collected packages: jax Building wheel for jax (setup.py) ... done Created wheel for jax: filename=jax-0.2.6-cp36-none-any.whl size=606591 sha256=d4a0a4550f5d613321df9534315c72dc39ae5a552163d3be201603eca460b5a7 Stored in directory: /tmp/pip-ephem-wheel-cache-cuqvuuby/wheels/99/39/0d/df246aefe5c610292921f884fdf7709e8bfb9b118f22da8c85 Successfully built jax Installing collected packages: jax Found existing installation: jax 0.2.6 Uninstalling jax-0.2.6: Successfully uninstalled jax-0.2.6 Successfully installed jax-0.2.6
import jax
import jax.config
import scipy as sp
import jax.numpy as jnp
import scipy.sparse.linalg
import numpy as np
jax.config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt
import time
import functools
def gmres_incremental(A, b, restart=None):
    """Run a single GMRES restart cycle with JAX's incremental solver.

    One cycle of ``restart`` Arnoldi iterations is requested (``maxiter=1``)
    with both tolerances set to zero, so the tolerance checks do not cut the
    cycle short. This keeps the work comparable with the other benchmark
    wrappers.

    Args:
        A: Square matrix; applied to vectors as ``jnp.dot(A, v)``.
        b: Right-hand side vector.
        restart: Krylov subspace size for the cycle. When omitted, the
            module-level ``restart`` (rebound by the benchmark loop) is used,
            so existing ``gmres_incremental(A, b)`` calls keep working.

    Returns:
        The ``(x, info)`` pair returned by ``jax.scipy.sparse.linalg.gmres``.
    """
    if restart is None:
        # Read at trace time when wrapped in jax.jit, so a fresh jit object is
        # needed whenever the notebook-level ``restart`` changes.
        restart = globals()["restart"]
    matvec = functools.partial(jnp.dot, A)
    # Name the solver explicitly: the function promises the incremental
    # variant and must not depend on whatever the library default is.
    return jax.scipy.sparse.linalg.gmres(
        matvec,
        b,
        restart=restart,
        maxiter=1,
        atol=0,
        tol=0,
        solve_method="incremental",
    )
def gmres_direct(A, b):
    """Run a single GMRES restart cycle with ``solve_method='direct'``.

    Mirrors the other benchmark wrappers: one cycle (``maxiter=1``) of
    ``restart`` iterations with zero ``tol`` and ``atol``. ``restart`` is
    taken from the module-level variable that the benchmark loop rebinds on
    every pass (read at trace time under ``jax.jit``).

    NOTE(review): ``'direct'`` is the mode name on the branch installed at
    the top of this notebook; released JAX may spell it differently
    (e.g. ``'batched'``) — confirm before running elsewhere.

    Returns:
        The ``(x, info)`` pair returned by ``jax.scipy.sparse.linalg.gmres``.
    """
    def matvec(v):
        # Dense matrix-vector product against the captured matrix ``A``.
        return jnp.dot(A, v)

    solver_options = dict(
        restart=restart,
        maxiter=1,
        atol=0,
        tol=0,
        solve_method='direct',
    )
    return jax.scipy.sparse.linalg.gmres(matvec, b, **solver_options)
def gmres_scipy(A, b):
return scipy.sparse.linalg.gmres(A_, b_, restart=restart, maxiter=1, atol=0, tol=0)
for N, restart in [
(20, 10),
(200, 50),
(2000, 200),
]:
print(f"\nN={N}, restart={restart}")
A = jax.random.normal(jax.random.PRNGKey(0), (N, N))
b = jax.random.normal(jax.random.PRNGKey(1), (N,))
A_, b_ = np.asarray(A), np.asarray(b)
print("SciPy CPU:")
%timeit gmres_scipy(A_, b_)
print("JAX incremental CPU:")
gmres_ = jax.jit(gmres_incremental, backend='cpu')
gmres_(A, b)[0].block_until_ready()
%timeit gmres_(A, b)[0].block_until_ready()
print("JAX direct CPU:")
gmres_ = jax.jit(gmres_direct, backend='cpu')
gmres_(A, b)[0].block_until_ready()
%timeit gmres_(A, b)[0].block_until_ready()
print("JAX incremental GPU:")
gmres_ = jax.jit(gmres_incremental, backend='gpu')
gmres_(A, b)[0].block_until_ready()
%timeit gmres_(A, b)[0].block_until_ready()
print("JAX direct GPU:")
gmres_ = jax.jit(gmres_direct, backend='gpu')
gmres_(A, b)[0].block_until_ready()
%timeit gmres_(A, b)[0].block_until_ready()
N=20, restart=10 SciPy CPU: 1000 loops, best of 3: 355 µs per loop JAX incremental CPU: 10000 loops, best of 3: 108 µs per loop JAX direct CPU: 10000 loops, best of 3: 112 µs per loop JAX incremental GPU: 100 loops, best of 3: 4.83 ms per loop JAX direct GPU: 100 loops, best of 3: 1.92 ms per loop N=200, restart=50 SciPy CPU: 100 loops, best of 3: 2.63 ms per loop JAX incremental CPU: 1000 loops, best of 3: 762 µs per loop JAX direct CPU: 1000 loops, best of 3: 1.05 ms per loop JAX incremental GPU: 10 loops, best of 3: 67.6 ms per loop JAX direct GPU: 100 loops, best of 3: 8.77 ms per loop N=2000, restart=100 SciPy CPU: 10 loops, best of 3: 163 ms per loop JAX incremental CPU: 10 loops, best of 3: 174 ms per loop JAX direct CPU: 10 loops, best of 3: 176 ms per loop JAX incremental GPU: 1 loop, best of 3: 252 ms per loop JAX direct GPU: 10 loops, best of 3: 38.5 ms per loop