# Copyright 2020 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import functools
from jax import random
from jax import lax
import jax.numpy as jnp
import jax
import jax.scipy as jsp
from jax.tree_util import Partial
import scipy.sparse.linalg
def _identity(x):
  return x
def _inner(v, q):
  # One step of modified Gram-Schmidt: project the component along q out of v.
  h_jk = q.conj() @ v
  v = v - h_jk * q
  return (v, h_jk)
def _outer(A, M, Q, k):
  # One Arnoldi step: apply the preconditioned operator to the k-th basis
  # vector, orthogonalize against every column of Q, and append the result.
  q = Q[:, k]
  v = A(M(q))
  # TODO: maybe better to use a masked dot-product rather than scan?
  # (A possible masked variant is sketched after this function.)
  v, h_col = lax.scan(_inner, v, Q.T)
  v_norm = jnp.linalg.norm(v)
  Q = Q.at[:, k + 1].set(v / v_norm)
  h_col = h_col.at[k + 1].set(v_norm)
  return Q, h_col
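Responding to the TODO above, here is a possible masked-dot-product variant (a sketch, not what the rest of this notebook uses). It replaces the scan with a single matrix-vector product, i.e. classical rather than modified Gram-Schmidt, which vectorizes better but is a bit less numerically stable; since the unfilled columns of Q are zero, the mask mainly documents intent.
def _outer_masked(A, M, Q, k):
  q = Q[:, k]
  v = A(M(q))
  # Project against all filled columns at once (classical Gram-Schmidt).
  mask = jnp.arange(Q.shape[1]) <= k
  h_col = jnp.where(mask, Q.conj().T @ v, 0.0)
  v = v - Q @ h_col
  v_norm = jnp.linalg.norm(v)
  Q = Q.at[:, k + 1].set(v / v_norm)
  h_col = h_col.at[k + 1].set(v_norm)
  return Q, h_col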
def arnoldi_iteration(A, b, n, M=None):
  # https://en.wikipedia.org/wiki/Arnoldi_iteration#The_Arnoldi_iteration
  # Returns Q of shape (m, n+1), whose orthonormal columns span the Krylov
  # subspace of the preconditioned operator A∘M generated by b, and the
  # corresponding (n+1, n) upper Hessenberg matrix H.
  if M is None:
    M = _identity
  m = b.shape[0]
  q = b / jnp.linalg.norm(b)
  Q = jnp.concatenate([q[:, jnp.newaxis], jnp.zeros((m, n))], axis=1)
  Q, h = lax.scan(functools.partial(_outer, A, M), Q, np.arange(n))
  return Q, h.T
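As a quick sanity check (a sketch with made-up test inputs, not part of the original), the columns of Q should be orthonormal and satisfy the Arnoldi relation A @ Q[:, :n] ≈ Q @ H when no preconditioner is passed:
A_check = random.normal(random.PRNGKey(2), (10, 10))
b_check = random.normal(random.PRNGKey(3), (10,))
Q_check, H_check = arnoldi_iteration(functools.partial(jnp.dot, A_check), b_check, n=5)
np.testing.assert_allclose(Q_check.T @ Q_check, np.eye(6), atol=1e-4)
np.testing.assert_allclose(A_check @ Q_check[:, :-1], Q_check @ H_check, atol=1e-4)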
@jax.jit
def lstsq(a, b):
  # Solve the least-squares problem min_x ||a @ x - b|| via the normal
  # equations; cheap and accurate enough for the small, reasonably
  # well-conditioned systems that arise here.
  return jsp.linalg.solve(a.T @ a, a.T @ b, sym_pos=True)
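A small check (again a sketch with hypothetical inputs) that the normal-equations shortcut agrees with jnp.linalg.lstsq on a well-conditioned tall matrix:
a_ls = random.normal(random.PRNGKey(4), (20, 5))
b_ls = random.normal(random.PRNGKey(5), (20,))
np.testing.assert_allclose(lstsq(a_ls, b_ls), jnp.linalg.lstsq(a_ls, b_ls)[0], atol=1e-4)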
def _gmres(A, b, x0, n, M):
  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
  Q, H = arnoldi_iteration(A, b, n, M)
  beta = jnp.linalg.norm(b - A(x0))
  e1 = jnp.concatenate([jnp.ones((1,)), jnp.zeros((n,))])
  # Minimize the residual over the Krylov subspace, then map back through M.
  y = lstsq(H, beta * e1)
  x = x0 + M(Q[:, :-1] @ y)
  return x
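In other words, with the default x0 = 0 and no preconditioner, the Arnoldi relation $A Q_n = Q_{n+1} \tilde{H}_n$ reduces minimizing the residual over the Krylov subspace (writing $x = Q_n y$) to the small least-squares problem

$$\min_y \lVert \beta e_1 - \tilde{H}_n y \rVert_2, \qquad \beta = \lVert b \rVert_2,$$

which is exactly what lstsq(H, beta * e1) computes.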
def gmres(A, b, x0=None, n=5, M=None):
  if x0 is None:
    x0 = jnp.zeros_like(b)
  if M is None:
    M = _identity
  return _gmres(A, b, x0, n, M)
Verify correctness:
A = random.normal(random.PRNGKey(0), (100, 100))
b = random.normal(random.PRNGKey(1), (100,))
np.testing.assert_allclose(
    gmres(functools.partial(jnp.dot, A), b, n=20),
    scipy.sparse.linalg.gmres(np.array(A), np.array(b), restart=20, maxiter=1)[0],
    atol=1e-6,
)
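The M argument acts as a right preconditioner. As a sketch (not in the original notebook), here is a simple Jacobi (inverse-diagonal) preconditioner applied to a hypothetical diagonally dominant matrix:
A_dd = A + 50 * jnp.eye(100)  # diagonally dominant, so its diagonal makes a useful preconditioner
M_jacobi = lambda x: x / jnp.diag(A_dd)
np.testing.assert_allclose(
    gmres(functools.partial(jnp.dot, A_dd), b, n=20, M=M_jacobi),
    np.linalg.solve(np.array(A_dd), np.array(b)),
    atol=1e-4,
)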
Verify that we can calculate gradients through a fixed number of GMRES iterations.
(Note that if you're running GMRES to convergence, there's a better way to calculate gradients via the adjoint rule; a sketch follows the gradient output below.)
@jax.grad
def loss(A, b):
  return jnp.sum(gmres(functools.partial(jnp.dot, A), b))
loss(A, b)
DeviceArray([[-0.00888863, -0.01108986, -0.01395201, ..., -0.01434983, -0.00233695,  0.0087676 ],
             [ 0.0068522 ,  0.00968967,  0.00116034, ..., -0.0108919 , -0.00220353,  0.01377204],
             [-0.00557137, -0.00477795, -0.01392099, ..., -0.01569235, -0.00254974,  0.01301789],
             ...,
             [-0.00446858, -0.00590282, -0.00807489, ..., -0.01217442, -0.00532267,  0.01113929],
             [ 0.00431957,  0.00333034,  0.00053749, ...,  0.00552948,  0.00076819, -0.0026694 ],
             [ 0.00614916,  0.00756274,  0.00051342, ..., -0.00826862, -0.00276195,  0.01154379]],
            dtype=float32)
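For completeness, here is a sketch of the adjoint approach mentioned above (it assumes the solver is run close to convergence, and is not part of the original notebook). Wrapping the solve in lax.custom_linear_solve tells JAX to differentiate by solving one extra linear system with the transposed operator instead of unrolling all of the iterations:
def gmres_with_adjoint(A, b, n=20):
  # The forward pass still runs our gmres; for gradients, JAX calls
  # transpose_solve with the transposed matvec (`vecmat`), so the same gmres
  # code solves the adjoint system.
  matvec = functools.partial(jnp.dot, A)
  def solve(matvec, rhs):
    return gmres(matvec, rhs, n=n)
  return lax.custom_linear_solve(matvec, b, solve, transpose_solve=solve)

@jax.grad
def adjoint_loss(A, b):
  return jnp.sum(gmres_with_adjoint(A, b))

adjoint_loss(A, b)  # same shape as the gradient above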
Despite our naive implementation, out-of-the-box CPU performance beats SciPy by about 2x:
@functools.partial(jax.jit, static_argnums=(2,))
def explicit_gmres(A, b, n):
  return gmres(functools.partial(jnp.dot, A), b, n=n)
# CPU
%timeit explicit_gmres(A, b, 30).block_until_ready()
The slowest run took 1690.02 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 649 µs per loop
# GPU
%timeit explicit_gmres(A, b, 30).block_until_ready()
10 loops, best of 3: 29.9 ms per loop
b2 = np.asarray(b)
A2 = np.asarray(A)
%timeit scipy.sparse.linalg.gmres(A2, b2, restart=30, maxiter=1)
1000 loops, best of 3: 1.46 ms per loop
We can also `vmap` it! This gives us a big speed-up on GPUs:
A_stack = random.normal(random.PRNGKey(0), (1000, 100, 100))
stacked_explicit_gmres = jax.jit(jax.vmap(explicit_gmres, in_axes=(0, None, None)), static_argnums=(2,))
# CPU
%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()
1 loop, best of 3: 821 ms per loop
# GPU
%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()
10 loops, best of 3: 31.2 ms per loop