# Copyright 2020 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import functools

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
import scipy.sparse.linalg
from jax import lax
from jax import random
def _identity(x):
  return x
_dot = functools.partial(jnp.dot, precision=lax.Precision.HIGHEST)
def _iterative_classical_gram_schmidt(Q, x, iterations=2):
  """Orthogonalize x against the columns of Q, returning the orthogonalized
  vector and the projection coefficients."""
  # "twice is enough"
  # http://slepc.upv.es/documentation/reports/str1.pdf
  q = x
  r = 0
  for _ in range(iterations):
    h = _dot(Q.T.conj(), q)  # project onto the existing basis
    q = q - _dot(Q, h)       # subtract the projection
    r = r + h                # accumulate the coefficients
  return q, r
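# Aside (a sketch, not from the original notebook): the "twice is enough"
# heuristic is easy to see empirically. Q_demo and x_demo are made-up names;
# the residual overlap |Q.T @ q| should shrink after a second pass.
Q_demo, _ = jnp.linalg.qr(random.normal(random.PRNGKey(4), (100, 10)))
x_demo = random.normal(random.PRNGKey(5), (100,))
q1, _ = _iterative_classical_gram_schmidt(Q_demo, x_demo, iterations=1)
q2, _ = _iterative_classical_gram_schmidt(Q_demo, x_demo, iterations=2)
print(jnp.abs(_dot(Q_demo.T, q1)).max(), jnp.abs(_dot(Q_demo.T, q2)).max())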
def arnoldi_iteration(A, b, n, M=None):
  """Run n steps of the Arnoldi iteration, returning the Krylov basis Q of
  shape (m, n+1) and H of shape (n, n+1), the transpose of the upper
  Hessenberg matrix."""
  # https://en.wikipedia.org/wiki/Arnoldi_iteration#The_Arnoldi_iteration
  if M is None:
    M = _identity
  m = b.shape[0]
  q = b / jnp.linalg.norm(b)
  Q = jnp.concatenate([q[:, jnp.newaxis], jnp.zeros((m, n))], axis=1)
  H = jnp.zeros((n, n + 1))

  def f(carry, k):
    Q, H = carry
    q = Q[:, k]
    v = A(M(q))
    v, h = _iterative_classical_gram_schmidt(Q, v, iterations=1)
    v_norm = jnp.linalg.norm(v)
    Q = Q.at[:, k + 1].set(v / v_norm)
    h = h.at[k + 1].set(v_norm)
    H = H.at[k, :].set(h)
    return (Q, H), None

  (Q, H), _ = lax.scan(f, (Q, H), jnp.arange(n))
  return Q, H
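# Sanity check (a sketch, not in the original notebook): n Arnoldi steps
# should satisfy the relation A @ Q_n = Q_{n+1} @ H_tilde, where H above
# stores the (n+1, n) Hessenberg matrix H_tilde transposed. A_check and
# b_check are hypothetical names.
A_check = random.normal(random.PRNGKey(2), (10, 10))
b_check = random.normal(random.PRNGKey(3), (10,))
Q_check, H_check = arnoldi_iteration(functools.partial(jnp.dot, A_check), b_check, 5)
np.testing.assert_allclose(
    _dot(A_check, Q_check[:, :-1]), _dot(Q_check, H_check.T), atol=1e-4)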
@jax.jit
def lstsq(a, b):
  # solve the normal equations; slightly faster than jnp.linalg.lstsq
  return jsp.linalg.solve(_dot(a.T, a), _dot(a.T, b), assume_a='pos')
def _gmres(A, b, x0, n, M):
  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
  # TODO: exit based on achieving some error tolerance
  r0 = b - A(x0)  # build the Krylov space from the residual, so nonzero x0 works
  beta = jnp.linalg.norm(r0)
  Q, H = arnoldi_iteration(A, r0, n, M)
  e1 = jnp.concatenate([jnp.ones((1,)), jnp.zeros((n,))])
  y = lstsq(H.T, beta * e1)
  x = x0 + M(_dot(Q[:, :-1], y))
  return x
def gmres(A, b, x0=None, n=5, M=None):
  """Solve A x = b with n steps of GMRES, with optional right preconditioner M."""
  if x0 is None:
    x0 = jnp.zeros_like(b)
  if M is None:
    M = _identity
  return _gmres(A, b, x0, n, M)
Verify correctness against SciPy:
A = random.normal(random.PRNGKey(0), (100, 100))
b = random.normal(random.PRNGKey(1), (100,))
np.testing.assert_allclose(
    gmres(functools.partial(jnp.dot, A), b, n=20),
    scipy.sparse.linalg.gmres(np.array(A), np.array(b), restart=20, maxiter=1)[0],
    atol=1e-6,
)
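The M argument supplies a right preconditioner. As an illustration (a sketch, not from the original notebook; A_dd and M_jacobi are made-up names), a Jacobi preconditioner just divides by the diagonal, which is sensible when the matrix is diagonally dominant:
# make a diagonally dominant system so the Jacobi preconditioner makes sense
A_dd = A + 20.0 * jnp.eye(100)
M_jacobi = lambda x: x / jnp.diag(A_dd)
x_pre = gmres(functools.partial(jnp.dot, A_dd), b, n=20, M=M_jacobi)
print(jnp.linalg.norm(b - _dot(A_dd, x_pre)))  # residual should be far below ||b||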
Verify that we can calculate gradients through a fixed number of GMRES iterations.
(Note that if you're running GMRES to convergence, there's a better way to calculate gradients via the adjoint rule; see the sketch after the output below.)
@jax.grad
def loss(A, b):
  return jnp.sum(gmres(functools.partial(jnp.dot, A), b))
loss(A, b)
DeviceArray([[-0.00888865, -0.0110899 , -0.01395201, ..., -0.01434979,
              -0.00233699,  0.0087676 ],
             [ 0.00685218,  0.00968965,  0.00116033, ..., -0.0108919 ,
              -0.00220355,  0.01377206],
             [-0.00557139, -0.00477797, -0.01392098, ..., -0.01569233,
              -0.00254976,  0.01301789],
             ...,
             [-0.00446863, -0.00590283, -0.00807492, ..., -0.01217444,
              -0.00532266,  0.0111393 ],
             [ 0.00431959,  0.00333032,  0.0005375 , ...,  0.00552948,
               0.00076819, -0.0026694 ],
             [ 0.00614914,  0.00756271,  0.0005134 , ..., -0.00826863,
              -0.00276195,  0.01154379]], dtype=float32)
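The adjoint rule mentioned above is easy to sketch with jax.custom_vjp (an illustration under the assumption that GMRES has converged, i.e. x = A^{-1} b holds exactly; gmres_solve and its helpers are made-up names, not part of the original). For x = A^{-1} b with output cotangent x_bar, the input cotangents are b_bar = A^{-T} x_bar and A_bar = -b_bar x^T, so the backward pass is just one more GMRES solve with the transposed matrix:
@jax.custom_vjp
def gmres_solve(A, b):
  return gmres(functools.partial(jnp.dot, A), b, n=20)

def _gmres_solve_fwd(A, b):
  x = gmres(functools.partial(jnp.dot, A), b, n=20)
  return x, (A, x)

def _gmres_solve_bwd(res, x_bar):
  A, x = res
  # adjoint solve: b_bar = A^{-T} x_bar, approximated by GMRES on A.T
  b_bar = gmres(functools.partial(jnp.dot, A.T), x_bar, n=20)
  return -jnp.outer(b_bar, x), b_bar  # cotangents for (A, b)

gmres_solve.defvjp(_gmres_solve_fwd, _gmres_solve_bwd)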
Despite our naive implementation, out-of-the-box CPU performance beats SciPy by about 3x:
@functools.partial(jax.jit, static_argnums=(2,))
def explicit_gmres(A, b, n):
  return gmres(functools.partial(jnp.dot, A), b, n=n)
# scipy CPU
b2 = np.asarray(b)
A2 = np.asarray(A)
%timeit scipy.sparse.linalg.gmres(A2, b2, restart=30, maxiter=1)
1000 loops, best of 3: 1.49 ms per loop
# CPU
%timeit explicit_gmres(A, b, 30).block_until_ready()
The slowest run took 6.77 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 499 µs per loop
The GPU is a bit slower (for this matrix size) because there's not enough compute happening inside each loop iteration:
# GPU
%timeit explicit_gmres(A, b, 30).block_until_ready()
100 loops, best of 3: 2.66 ms per loop
We can also vmap it! This gives us a big speed-up on GPUs:
A_stack = random.normal(random.PRNGKey(0), (1000, 100, 100))
stacked_explicit_gmres = jax.jit(
    jax.vmap(explicit_gmres, in_axes=(0, None, None)), static_argnums=(2,))
# CPU
%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()
1 loop, best of 3: 416 ms per loop
# GPU
%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()
10 loops, best of 3: 24.5 ms per loop