# Copyright 2020 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import functools

import numpy as np
import scipy.sparse.linalg

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import lax
from jax import random


def _identity(x):
  return x


_dot = functools.partial(jnp.dot, precision=lax.Precision.HIGHEST)


def _iterative_classical_gram_schmidt(Q, x, iterations=2):
  """Orthogonalize x against the columns of Q."""
  # "twice is enough"
  # http://slepc.upv.es/documentation/reports/str1.pdf
  q = x
  r = 0
  for _ in range(iterations):
    h = _dot(Q.T.conj(), q)
    q = q - _dot(Q, h)
    r = r + h
  return q, r


def arnoldi_iteration(A, b, n, M=None):
  """Runs n steps of the Arnoldi iteration on A(M(.)), starting from b.

  Returns the Krylov basis Q of shape (m, n+1) and H of shape (n, n+1),
  where H.T is the (n+1, n) upper Hessenberg matrix.
  """
  # https://en.wikipedia.org/wiki/Arnoldi_iteration#The_Arnoldi_iteration
  if M is None:
    M = _identity
  m = b.shape[0]
  q = b / jnp.linalg.norm(b)
  Q = jnp.concatenate([q[:, jnp.newaxis], jnp.zeros((m, n))], axis=1)
  H = jnp.zeros((n, n + 1))

  def f(carry, k):
    Q, H = carry
    q = Q[:, k]
    v = A(M(q))
    v, h = _iterative_classical_gram_schmidt(Q, v, iterations=1)
    v_norm = jnp.linalg.norm(v)
    Q = Q.at[:, k + 1].set(v / v_norm)
    h = h.at[k + 1].set(v_norm)
    H = H.at[k, :].set(h)
    return (Q, H), None

  (Q, H), _ = lax.scan(f, (Q, H), jnp.arange(n))
  return Q, H


@jax.jit
def lstsq(a, b):
  # slightly faster than jnp.linalg.lstsq
  return jsp.linalg.solve(_dot(a.T, a), _dot(a.T, b), sym_pos=True)


def _gmres(A, b, x0, n, M):
  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
  # TODO: exit based on achieving some error tolerance
  # Build the Krylov space from the initial residual.
  residual = b - A(x0)
  beta = jnp.linalg.norm(residual)
  Q, H = arnoldi_iteration(A, residual, n, M)
  e1 = jnp.concatenate([jnp.ones((1,)), jnp.zeros((n,))])
  # Minimize ||beta * e1 - H.T y|| and map back through the preconditioner.
  y = lstsq(H.T, beta * e1)
  x = x0 + M(_dot(Q[:, :-1], y))
  return x


def gmres(A, b, x0=None, n=5, M=None):
  """Solves A x = b with n steps of GMRES (one restart cycle).

  A is a function computing the matrix-vector product; M is an optional
  right preconditioner, also given as a function.
  """
  if x0 is None:
    x0 = jnp.zeros_like(b)
  if M is None:
    M = _identity
  return _gmres(A, b, x0, n, M)


A = random.normal(random.PRNGKey(0), (100, 100))
b = random.normal(random.PRNGKey(1), (100,))

np.testing.assert_allclose(
    gmres(functools.partial(jnp.dot, A), b, n=20),
    scipy.sparse.linalg.gmres(np.array(A), np.array(b), restart=20, maxiter=1)[0],
    atol=1e-6,
)


@jax.grad
def loss(A, b):
  return jnp.sum(gmres(functools.partial(jnp.dot, A), b))


loss(A, b)


@functools.partial(jax.jit, static_argnums=(2,))
def explicit_gmres(A, b, n):
  return gmres(functools.partial(jnp.dot, A), b, n=n)


# scipy CPU
b2 = np.asarray(b)
A2 = np.asarray(A)
%timeit scipy.sparse.linalg.gmres(A2, b2, restart=30, maxiter=1)

# CPU
%timeit explicit_gmres(A, b, 30).block_until_ready()

# GPU
%timeit explicit_gmres(A, b, 30).block_until_ready()

A_stack = random.normal(random.PRNGKey(0), (1000, 100, 100))
stacked_explicit_gmres = jax.jit(
    jax.vmap(explicit_gmres, in_axes=(0, None, None)), static_argnums=(2,))

# CPU
%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()

# GPU
%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()
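
# --- Illustrative sketch, not part of the original notebook: the `M` argument
# of `gmres` is defined above but never exercised. `gmres` applies `M` as a
# right preconditioner (the correction is computed in the preconditioned
# Krylov space and mapped back through M). Assuming a shifted matrix `A_dd`
# constructed purely for illustration, a Jacobi (diagonal) preconditioner
# could be passed like this; this only shows the calling convention, not a
# speedup claim.
A_dd = A + 20.0 * jnp.eye(A.shape[0])       # hypothetical, better-conditioned matrix
inv_diag = 1.0 / jnp.diag(A_dd)
jacobi_M = lambda v: inv_diag * v           # M ~= diag(A_dd)^-1
x_jacobi = gmres(functools.partial(jnp.dot, A_dd), b, n=20, M=jacobi_M)
print(jnp.linalg.norm(A_dd @ x_jacobi - b))  # residual norm of the preconditioned solve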