`cg` for solving the Poisson equation
# Copyright 2020 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import jax.numpy as jnp
import numpy as np
import scipy.sparse.linalg
import jax
from functools import partial
import matplotlib.pyplot as plt
def axis_slice(ndim, index, axis):
    """Return an indexing tuple that applies `index` to a single axis.

    Every one of the `ndim` positions is ``slice(None)`` (i.e. ``:``) except
    position `axis`, which holds `index`. A negative `axis` counts from the
    end, exactly like list indexing.
    """
    selector = [slice(None) for _ in range(ndim)]
    selector[axis] = index
    return tuple(selector)
def slice_along_axis(array, index, axis):
    """Index `array` with `index` along `axis`, leaving every other axis whole."""
    selector = axis_slice(array.ndim, index, axis)
    return array[selector]
def shift(array, offset, axis):
    """Shift `array` by `offset` positions along `axis`, zero-filling vacated cells.

    Along `axis`, ``result[i] == array[i + offset]`` whenever
    ``0 <= i + offset < n`` (``n = array.shape[axis]``) and 0 otherwise, so a
    positive offset zeroes the trailing ``offset`` entries and a negative
    offset zeroes the leading ``-offset`` entries.

    Args:
      array: array to shift (NumPy or JAX).
      offset: integer shift; may be negative, zero, or larger in magnitude
        than the axis length.
      axis: axis along which to shift.

    Returns:
      A JAX array with the same shape as `array`.
    """
    if abs(offset) >= array.shape[axis]:
        # Every entry is shifted out. The slice-and-pad path below would pad
        # by `offset` after slicing fewer than `offset` entries away, growing
        # the axis, so return the (all-zero) answer directly.
        return jnp.zeros_like(array)
    index = slice(offset, None) if offset >= 0 else slice(None, offset)
    sliced = slice_along_axis(array, index, axis)
    padding = [(0, 0)] * array.ndim
    # Pad on the side that was sliced away so the shape is restored.
    padding[axis] = (-min(offset, 0), max(offset, 0))
    return jnp.pad(sliced, padding, mode='constant', constant_values=0)
def laplacian(array):
    """Apply the negated nearest-neighbour discrete Laplacian (unit spacing).

    Computes ``2 * ndim * u`` minus the sum of the ``2 * ndim`` axis-aligned
    neighbours, treating values outside the array as zero (a homogeneous
    Dirichlet boundary). For a 2-D input this is the 5-point stencil
    ``4*u[i,j] - u[i+1,j] - u[i-1,j] - u[i,j+1] - u[i,j-1]``.

    This is the *negative* of the usual discrete Laplacian. That sign makes
    the operator symmetric positive definite, as conjugate gradients requires.

    Args:
      array: grid values of any dimensionality (2-D in this notebook).

    Returns:
      Array of the same shape as `array`.
    """
    # note: I believe this is faster than a convolution (at least on most platforms)
    neighbours = 0
    for axis in range(array.ndim):
        # Accumulate in the same order as the original 2-D formula
        # ((left + right) + up) + down, so 2-D results are unchanged.
        neighbours = neighbours + shift(array, +1, axis=axis) + shift(array, -1, axis=axis)
    return -neighbours + 2 * array.ndim * array
def laplacian_flat(array):
    """Apply `laplacian` to a flattened square 2-D grid and return a flat result.

    Args:
      array: 1-D array of length ``n * n`` holding an ``(n, n)`` grid in
        row-major order.

    Returns:
      1-D array of length ``n * n``.

    Raises:
      ValueError: if the length of `array` is not a perfect square.
    """
    length = array.shape[0]
    # round() rather than int(): truncating the float sqrt can land one below
    # the true side length once length**0.5 is not exactly representable.
    size = int(round(length ** 0.5))
    if size * size != length:
        raise ValueError(
            f'laplacian_flat expects a flattened square grid; got length {length}')
    grid = array.reshape(size, size)
    return laplacian(grid).reshape(-1)
def make_source(shape):
    """Build a 2-D array that is zero except for prescribed values on its edges.

    Edge values (the column assignments are made last, so they own the corners):
      * first and last rows: linear ramp from 0 to 1 along axis 1;
      * first column: ``4 * t * (1 - t)`` for ``t`` running 0..1 along axis 0;
      * last column: ``1 - 4 * t * (1 - t)``.

    Args:
      shape: ``(rows, cols)`` of the grid; need not be square.

    Returns:
      A float64 NumPy array of the given shape.

    Raises:
      ValueError: if `shape` is not two-dimensional.
    """
    if len(shape) != 2:
        raise ValueError(f'make_source expects a 2-D shape, got {shape!r}')
    rows, cols = shape
    # The row edges run along axis 1 (length `cols`) and the column edges run
    # along axis 0 (length `rows`). Using each axis's own length lets
    # non-square grids work, and is identical to before for square ones.
    along_cols = np.linspace(0, 1, num=cols)
    along_rows = np.linspace(0, 1, num=rows)
    bump = 4 * along_rows * (1 - along_rows)
    source = np.zeros(shape)
    source[0, :] = along_cols
    source[-1, :] = along_cols
    source[:, 0] = bump
    source[:, -1] = 1 - bump
    return source
# The functions we'll be benchmarking
def jax_poisson_cg_solve(b, x0):
    """Solve ``laplacian(x) = b`` with JAX's conjugate gradient, starting at `x0`.

    `b` and `x0` keep their grid shape. Both tolerances are 0, so the solver
    does not stop early and runs up to ``MAX_ITER`` iterations. The second
    value returned by ``cg`` is discarded.
    """
    return jax.scipy.sparse.linalg.cg(
        laplacian, b, x0, tol=0, atol=0, maxiter=MAX_ITER)[0]
def jax_poisson_cg_solve_flat(b, x0):
    """Like `jax_poisson_cg_solve`, but CG works on flattened 1-D vectors.

    `b` and `x0` are raveled before solving, and `laplacian_flat` serves as
    the operator. The returned solution is therefore a flat vector.
    """
    rhs = b.ravel()
    initial_guess = x0.ravel()
    flat_solution, _ = jax.scipy.sparse.linalg.cg(
        laplacian_flat, rhs, initial_guess, tol=0, atol=0, maxiter=MAX_ITER)
    return flat_solution
# Simulation parameters shared by the SciPy and JAX solvers.
MAX_ITER = 500  # iteration cap passed as `maxiter` to every CG call
shape = (512,) * 2  # (rows, cols) of the square grid
@jax.jit
def matvec(x):
    """Apply `laplacian` to a flat vector viewed as a `shape`-sized grid.

    Wrapped in a SciPy LinearOperator below, so it takes and returns flat
    vectors.
    """
    grid = jnp.reshape(x, shape)
    applied = laplacian(grid)
    return applied.ravel()
source = make_source(shape)
x0 = np.zeros(shape).ravel()
b = -source.ravel()
A = scipy.sparse.linalg.LinearOperator(
(int(np.prod(shape)),) * 2, matvec, dtype=np.float32)
solution, info = scipy.sparse.linalg.cg(A, b, x0=x0, tol=0, atol=0, maxiter=MAX_ITER)
%timeit scipy.sparse.linalg.cg(A, b, x0=x0, tol=0, atol=0, maxiter=MAX_ITER)
print(f'Error: {np.linalg.norm(matvec(solution) - b)}')
1 loop, best of 3: 2.05 s per loop
Error: 0.037856701761484146
On CPU, JAX's jitted CG is about 2.5x faster than SciPy's (813 ms vs. 2.05 s).
On GPU, it is roughly another 30x faster (28.2 ms).
# Benchmark JAX's CG on the grid-shaped formulation on each backend in turn.
# NOTE(review): recent JAX releases deprecate the `backend=` argument of
# `jax.jit`, and 'gpu' fails on machines without a GPU backend -- confirm
# against the installed JAX version before running.
for backend in ['cpu', 'gpu']:
    source = make_source(shape)
    # block_until_ready() waits until the arrays are materialized, so their
    # creation does not leak into the timing below.
    b = jnp.asarray(-source).block_until_ready()
    x0 = jnp.zeros_like(source).block_until_ready()
    cg_solve = jax.jit(jax_poisson_cg_solve, backend=backend)
    # Warm-up call: jit compiles on the first call, which keeps compilation
    # out of the %timeit measurement.
    solution = cg_solve(b, x0).block_until_ready()
    print(f"{backend.upper()} test:")
    # IPython magic (notebook only). JAX dispatches asynchronously, so
    # block_until_ready() makes the timing cover the full solve.
    %timeit cg_solve(b, x0).block_until_ready()
    # Residual norm of laplacian(x) - b, written with b = -source.
    print(f'Error: {np.linalg.norm(laplacian(solution) + source)}')
CPU test:
1 loop, best of 3: 813 ms per loop
Error: 0.03532074764370918
GPU test:
10 loops, best of 3: 28.2 ms per loop
Error: 0.03786032646894455
# Same benchmark, but CG operates on flattened vectors (laplacian_flat).
# NOTE(review): recent JAX releases deprecate the `backend=` argument of
# `jax.jit`, and 'gpu' fails on machines without a GPU backend -- confirm
# against the installed JAX version before running.
for backend in ['cpu', 'gpu']:
    source = make_source(shape)
    # block_until_ready() waits until the arrays are materialized, so their
    # creation does not leak into the timing below.
    b = jnp.asarray(-source).block_until_ready()
    x0 = jnp.zeros_like(source).block_until_ready()
    cg_solve = jax.jit(jax_poisson_cg_solve_flat, backend=backend)
    # Warm-up call: jit compiles on the first call, which keeps compilation
    # out of the %timeit measurement.
    solution = cg_solve(b, x0).block_until_ready()
    print(f"{backend.upper()} test:")
    # IPython magic (notebook only). JAX dispatches asynchronously, so
    # block_until_ready() makes the timing cover the full solve.
    %timeit cg_solve(b, x0).block_until_ready()
    # The solution is flat here, so compare against the raveled source.
    print(f'Error: {np.linalg.norm(laplacian_flat(solution) + source.ravel())}')
CPU test:
1 loop, best of 3: 1.28 s per loop
Error: 0.03532460331916809
GPU test:
10 loops, best of 3: 53.6 ms per loop
Error: 0.03785990551114082
Verify that we did, indeed, approximately solve the Poisson equation by plotting the last computed solution:
# Plot the most recent solution (a flat vector) reshaped back onto the grid.
import matplotlib.pyplot as plt
grid_solution = solution.reshape(shape)
plt.imshow(grid_solution)
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f1e70ef0eb8>