In this notebook, we solve the 2D wave equation: $$ \frac{\partial^2 u}{\partial t^2} = c^2 \nabla^2 u $$
We use a simple finite difference formulation with Leapfrog time integration.
Note: It is natural to express finite difference methods as convolutions, but here we intentionally avoid convolutions in favor of array indexing/arithmetic. This is because the "batch" and "feature" dimensions in TPU convolutions are padded to multiples of either 8 or 128, but in our case both of these dimensions are effectively of size 1.
# Grab other packages for this demo.
!pip install -U -q Pillow moviepy proglog scikit-image
# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
# Ask the Colab TPU host for a specific TPU driver build. The
# TPU_DRIVER_MODE global acts as a guard so that re-running this cell does
# not send the request again.
if 'TPU_DRIVER_MODE' not in globals():
    tpu_host = os.environ['COLAB_TPU_ADDR'].split(':')[0]
    url = 'http://' + tpu_host + ':8475/requestversion/tpu_driver0.1-dev20191206'
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config

config.FLAGS.jax_xla_backend = "tpu_driver"
# The backend target is the gRPC endpoint advertised by the Colab runtime.
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
from functools import partial
from functools import wraps

import jax
from jax import jit, pmap
from jax import lax
from jax import tree_util
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import skimage.filters
import proglog
from moviepy.editor import ImageSequenceClip
# Number of devices reported by JAX. send_right/send_left use it to build
# their neighbour permutations and shard() uses it as the number of tiles.
device_count = jax.device_count()
# Spatial partitioning via halo exchange
def send_right(x, axis_name):
    """Pass `x` from each device to the device with the next higher index.

    Device ``i`` sends to device ``i + 1`` for ``i`` in
    ``range(device_count - 1)``; the last device sends nothing and device 0
    has no sender.
    """
    # Note: if some devices are omitted from the permutation, lax.ppermute
    # provides zeros instead. This gives us an easy way to apply Dirichlet
    # boundary conditions.
    right_perm = [(source, source + 1) for source in range(device_count - 1)]
    return lax.ppermute(x, perm=right_perm, axis_name=axis_name)
def send_left(x, axis_name):
    """Pass `x` from each device to the device with the next lower index.

    Device ``i + 1`` sends to device ``i`` for ``i`` in
    ``range(device_count - 1)``; device 0 sends nothing and the last device
    has no sender.
    """
    left_perm = [(target + 1, target) for target in range(device_count - 1)]
    return lax.ppermute(x, perm=left_perm, axis_name=axis_name)
def axis_slice(ndim, index, axis):
    """Build an indexing tuple that applies `index` on `axis` only.

    Every other one of the `ndim` positions is a full slice (``:``).
    """
    selector = [slice(None) for _ in range(ndim)]
    selector[axis] = index
    return tuple(selector)
def slice_along_axis(array, index, axis):
    """Index `array` with `index` on `axis`, keeping all other axes whole."""
    selector = axis_slice(array.ndim, index, axis)
    return array[selector]
def tree_vectorize(func):
    """Decorator: apply `func` to every leaf of a pytree first argument.

    The returned function takes a pytree as its first argument and maps
    `func` over its leaves with `tree_util.tree_map`. Any further positional
    and keyword arguments are forwarded unchanged to each per-leaf call.
    """
    # wraps() keeps the decorated function's name and docstring visible on
    # the returned wrapper.
    @wraps(func)
    def wrapper(x, *args, **kwargs):
        return tree_util.tree_map(lambda leaf: func(leaf, *args, **kwargs), x)
    return wrapper
@tree_vectorize
def halo_exchange_padding(array, padding=1, axis=0, axis_name='x'):
    """Grow `array` along `axis` by `padding` entries of neighbour data per side.

    The first and last `padding` entries along `axis` are sent to the
    neighbouring devices over the mapped axis `axis_name`, and what is
    received is concatenated onto the corresponding ends. Zero-dimensional
    inputs are returned as arrays without any exchange.

    Raises:
      ValueError: if `padding` is not greater than zero.
    """
    if not padding > 0:
        raise ValueError(f'invalid padding: {padding}')
    array = jnp.array(array)
    if array.ndim == 0:
        # Scalars have no axis to pad.
        return array
    leading_edge = slice_along_axis(array, slice(None, padding), axis)
    trailing_edge = slice_along_axis(array, slice(-padding, None), axis)
    # Everyone sends their leading edge left, so we receive the leading edge
    # of the device to our right; the mirror image holds for trailing edges.
    halo_after = send_left(leading_edge, axis_name)
    halo_before = send_right(trailing_edge, axis_name)
    return jnp.concatenate([halo_before, array, halo_after], axis)
@tree_vectorize
def halo_exchange_inplace(array, padding=1, axis=0, axis_name='x'):
    """Refresh the halo entries of an already padded `array`.

    `array` is expected to carry `padding` halo entries at both ends of
    `axis`. The `padding` entries just inside each halo are sent to the
    neighbouring devices over the mapped axis `axis_name`, and the received
    data overwrites the halo entries. Despite the name, the update is
    functional: a new array is returned.

    Raises:
      ValueError: if `padding` is not greater than zero (same check as
        halo_exchange_padding).
    """
    if not padding > 0:
        raise ValueError(f'invalid padding: {padding}')
    first_interior = slice_along_axis(array, slice(padding, 2 * padding), axis)
    last_interior = slice_along_axis(array, slice(-2 * padding, -padding), axis)
    # Everyone sends their first interior entries left, so we receive the
    # data for our right-hand halo, and vice versa.
    right_halo = send_left(first_interior, axis_name)
    left_halo = send_right(last_interior, axis_name)
    left_index = axis_slice(array.ndim, slice(None, padding), axis)
    right_index = axis_slice(array.ndim, slice(-padding, None), axis)
    if hasattr(array, 'at'):
        # jax.ops.index_update has been removed from JAX; the `.at[...].set`
        # property is its documented replacement.
        array = array.at[left_index].set(left_halo)
        return array.at[right_index].set(right_halo)
    # Legacy JAX releases without the `.at` property.
    array = jax.ops.index_update(array, left_index, left_halo)
    return jax.ops.index_update(array, right_index, right_halo)
# Reshaping inputs/outputs for pmap
def split_with_reshape(array, num_splits, *, split_axis=0, tile_id_axis=None):
    """Cut `array` into `num_splits` equal tiles along `split_axis`.

    The result has one more dimension than `array`: the tile index, placed at
    `tile_id_axis` (defaults to `split_axis`).

    Raises:
      ValueError: if the size of `split_axis` is not a multiple of
        `num_splits`.
    """
    target_axis = split_axis if tile_id_axis is None else tile_id_axis
    tile_size, leftover = divmod(array.shape[split_axis], num_splits)
    if leftover:
        raise ValueError('num_splits must equally divide the dimension size')
    # Replace the split dimension by the pair (num_splits, tile_size).
    tiled_shape = list(array.shape)
    tiled_shape[split_axis] = tile_size
    tiled_shape.insert(split_axis, num_splits)
    tiled = jnp.reshape(array, tiled_shape)
    return jnp.moveaxis(tiled, split_axis, target_axis)
def stack_with_reshape(array, *, split_axis=0, tile_id_axis=None):
    """Merge the tile dimension of `array` back into `split_axis`.

    The dimension at `tile_id_axis` (defaults to `split_axis`) is moved to
    `split_axis` and then collapsed together with the dimension following it.
    """
    source_axis = split_axis if tile_id_axis is None else tile_id_axis
    tiled = jnp.moveaxis(array, source_axis, split_axis)
    leading_dims = tiled.shape[:split_axis]
    trailing_dims = tiled.shape[split_axis + 2:]
    return jnp.reshape(tiled, leading_dims + (-1,) + trailing_dims)
def shard(func):
    """Decorator: run `func` on a state split into `device_count` tiles.

    Every leaf of the input pytree is split along its first axis into
    `device_count` tiles with split_with_reshape, `func` is called on the
    tiled pytree, and every leaf of its result is merged back with
    stack_with_reshape.
    """
    def wrapper(state):
        def split_leaf(leaf):
            return split_with_reshape(leaf, device_count)

        tiled_state = tree_util.tree_map(split_leaf, state)
        tiled_result = func(tiled_state)
        return tree_util.tree_map(stack_with_reshape, tiled_result)
    return wrapper
# Physics
def shift(array, offset, axis):
    """Shift `array` by `offset` entries along `axis`, filling with zeros.

    For a positive `offset` the first `offset` entries are dropped and zeros
    are appended, so ``result[i] == array[i + offset]``. For a negative
    `offset` the last entries are dropped and zeros are prepended. The shape
    is unchanged.
    """
    if offset >= 0:
        kept = slice(offset, None)
    else:
        kept = slice(None, offset)
    trimmed = slice_along_axis(array, kept, axis)
    pad_widths = [(0, 0) for _ in range(array.ndim)]
    pad_widths[axis] = (-min(offset, 0), max(offset, 0))
    return jnp.pad(trimmed, pad_widths, mode='constant', constant_values=0)
def laplacian(array, step=1):
    """Five-point finite-difference Laplacian over axes 0 and 1 of `array`.

    Neighbours beyond the array edges count as zero (see shift). The result
    is divided by ``step ** 2`` unless `step` equals 1.
    """
    stencil = (
        shift(array, +1, axis=0)
        + shift(array, -1, axis=0)
        + shift(array, +1, axis=1)
        + shift(array, -1, axis=1)
        - 4 * array
    )
    if step == 1:
        return stencil
    stencil *= (1 / step ** 2)
    return stencil
def scalar_wave_equation(u, c=1, dx=1):
    """Right-hand side of the wave equation: ``c ** 2`` times the Laplacian of `u`.

    `dx` is the grid spacing handed to laplacian.
    """
    speed_squared = c ** 2
    return speed_squared * laplacian(u, dx)
@jax.jit
def leapfrog_step(state, dt=0.5, c=1):
    """Advance ``(u, u_t)`` by one time step of size `dt`.

    The time derivative is updated first from the current field, and the
    field is then advanced with the *updated* derivative.
    """
    # https://en.wikipedia.org/wiki/Leapfrog_integration
    field, rate = state
    acceleration = scalar_wave_equation(field, c)
    rate = rate + acceleration * dt
    field = field + rate * dt
    return (field, rate)
# Time stepping
def multi_step(state, count, dt=1/jnp.sqrt(2), c=1):
    """Apply leapfrog_step `count` times to `state` inside a lax.fori_loop."""
    def body(_, current):
        return leapfrog_step(current, dt, c)

    return lax.fori_loop(0, count, body, state)
def multi_step_pmap(state, count, dt=1/jnp.sqrt(2), c=1, exchange_interval=1,
                    save_interval=1):
    """Integrate `state` across all devices and collect periodic snapshots.

    The state is split along its first axis over the devices (see shard).
    Each device pads its tile with `exchange_interval` halo entries, takes
    `exchange_interval` leapfrog steps, then refreshes the halos from its
    neighbours; this repeats until a snapshot is due.

    Args:
      state: pytree of arrays, e.g. ``(u, u_t)``.
      count: total number of time steps. Only ``count // save_interval``
        snapshots are computed, so a remainder is silently dropped.
      dt: time step passed to multi_step.
      c: wave speed passed to multi_step.
        NOTE(review): `c` is captured by closure rather than passed through
        shard, so a non-scalar `c` is presumably not split across devices --
        confirm before using an array here.
      exchange_interval: number of steps between halo exchanges; also the
        halo width.
      save_interval: number of steps between snapshots. Each snapshot runs
        ``save_interval // exchange_interval`` exchange rounds, so a
        remainder is silently dropped here as well.

    Returns:
      A pytree with the structure of `state` whose leaves are NumPy arrays
      stacked along a new leading axis: the initial state followed by one
      entry per snapshot.
    """

    def exchange_and_multi_step(state_padded):
        c_padded = halo_exchange_padding(c, exchange_interval)
        evolved = multi_step(state_padded, exchange_interval, dt, c_padded)
        return halo_exchange_inplace(evolved, exchange_interval)

    @shard
    @partial(jax.pmap, axis_name='x')
    def simulate_until_output(local_state):
        stop = save_interval // exchange_interval
        state_padded = halo_exchange_padding(local_state, exchange_interval)
        advanced = lax.fori_loop(
            0, stop, lambda i, s: exchange_and_multi_step(s), state_padded)
        # Strip the halos again before handing the tile back.
        xi = exchange_interval
        return tree_util.tree_map(lambda array: array[xi:-xi, ...], advanced)

    results = [state]
    for _ in range(count // save_interval):
        state = simulate_until_output(state)
        # Start the device-to-host copy now so it can overlap with the next
        # round of computation.
        tree_util.tree_map(lambda x: x.copy_to_host_async(), state)
        results.append(state)
    results = jax.device_get(results)

    # Stack matching leaves of all snapshots. tree_util.tree_multimap has
    # been removed from JAX, so flatten/unflatten explicitly instead; these
    # helpers exist in old and new releases alike.
    flattened = [tree_util.tree_flatten(snapshot) for snapshot in results]
    treedef = flattened[0][1]
    leaves_per_snapshot = [leaves for leaves, _ in flattened]
    stacked_leaves = [
        np.stack([np.array(leaf) for leaf in matching])
        for matching in zip(*leaves_per_snapshot)
    ]
    return tree_util.tree_unflatten(treedef, stacked_leaves)
# Jit-compiled version of multi_step.
multi_step_jit = jax.jit(multi_step)

# Grid coordinates: 8 * 1024 points on [0, 8) along x, 1024 points on [0, 1)
# along y, with x as the first array axis.
x = jnp.linspace(0, 8, num=8*1024, endpoint=False)
y = jnp.linspace(0, 1, num=1*1024, endpoint=False)
x_mesh, y_mesh = jnp.meshgrid(x, y, indexing='ij')

# Initial field: a disk of radius 0.1 centred at (1/3, 1/4), blurred.
# NOTE: smooth initial conditions are important, so we aren't exciting
# arbitrarily high frequencies (that cannot be resolved)
initial_disk = ((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) < 0.1 ** 2
u = skimage.filters.gaussian(initial_disk, sigma=1)

# Alternative initial conditions:
# u = jnp.exp(-((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) / 0.1 ** 2)
# u = skimage.filters.gaussian(
#     (x_mesh > 1/3) & (x_mesh < 1/2) & (y_mesh > 1/3) & (y_mesh < 1/2),
#     sigma=5)

# Start at rest: the initial time derivative is zero everywhere.
v = jnp.zeros_like(u)

c = 1  # could also use a 2D array matching the mesh shape

u.shape
%%time
# single TPU chip
# Baseline: 2**13 steps through the jitted multi_step, with no sharding.
# Only the final field is kept; the time derivative is discarded.
u_final, _ = multi_step_jit((u, v), count=2**13, c=c, dt=0.5)
%%time
# 8x TPU chips, 4x more steps in roughly half the time!
# save_interval equals count, so the result holds just two snapshots: the
# initial state and the final one. Halos are exchanged every 4 steps.
u_final, _ = multi_step_pmap(
(u, v), count=2**15, c=c, dt=0.5, exchange_interval=4, save_interval=2**15)
# Speed-up estimate; 18.3 and 10.3 are presumably the wall times in seconds
# reported by the two %%time cells, and the 4 accounts for 4x more steps.
18.3 / (10.3 / 4) # near linear scaling (8x would be perfect)
%%time
# save more outputs for a movie -- this is slow!
# One snapshot every 2**10 steps: 2**15 // 2**10 = 32 snapshots plus the
# initial state.
u_final, _ = multi_step_pmap(
(u, v), count=2**15, c=c, dt=0.2, exchange_interval=4, save_interval=2**10)
# Leading axis indexes the snapshots.
u_final.shape
# Size of the stacked snapshots in units of 1e9 bytes.
u_final.nbytes / 1e9
# Final snapshot, transposed so the first array axis runs horizontally.
plt.figure(figsize=(18, 6))
plt.axis('off')
plt.imshow(u_final[-1].T, cmap='RdBu');

# Initial state followed by eight later snapshots, one per row.
fig, axes = plt.subplots(9, 1, figsize=(14, 14))
for ax in axes:
    ax.axis('off')
axes[0].imshow(u_final[0].T, cmap='RdBu', aspect='equal', vmin=-1, vmax=1)
for i in range(8):
    frame = u_final[4*i+1]
    # Each frame is rescaled by its own peak magnitude so that it fills the
    # [-1, 1] colour range.
    axes[i+1].imshow(frame.T / abs(frame).max(), cmap='RdBu', aspect='equal',
                     vmin=-1, vmax=1)
import matplotlib.cm
import matplotlib.colors
from PIL import Image
def make_images(data, cmap='RdBu', vmax=None):
    """Render every frame of `data` as an RGBA PIL image.

    Values are mapped through `cmap` on the symmetric range
    ``[-vmax, vmax]``. When `vmax` is None, each frame uses its own maximum
    absolute value instead.
    """
    def render(frame):
        limit = np.max(abs(frame)) if vmax is None else vmax
        norm = matplotlib.colors.Normalize(vmin=-limit, vmax=limit)
        mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
        rgba = mappable.to_rgba(frame, bytes=True)
        return Image.fromarray(rgba, mode='RGBA')

    return [render(frame) for frame in data]
def save_movie(images, path, duration=100, loop=0, **kwargs):
    """Write `images` to `path` as one multi-frame file.

    Calls ``save`` on the first image with ``save_all=True`` and the
    remaining images as ``append_images``. `duration`, `loop` and any extra
    keyword arguments are forwarded to that call.
    """
    first_frame = images[0]
    remaining_frames = images[1:]
    first_frame.save(
        path,
        save_all=True,
        append_images=remaining_frames,
        duration=duration,
        loop=loop,
        **kwargs,
    )
# Keep every 8th grid point along both spatial axes and swap those two axes
# before rendering the frames.
images = make_images(u_final[::, ::8, ::8].transpose(0, 2, 1))

# Show Movie
# Binds None as the first argument of proglog's default bar logger
# (presumably to silence the progress bar -- confirm).
proglog.default_bar_logger = partial(proglog.default_bar_logger, None)
frames = [np.array(im) for im in images]
ImageSequenceClip(frames, fps=25).ipython_display()

# Save GIF.
# The first and last frames are held for 2000 ms, all others for 200 ms.
frame_durations = [2000] + [200] * (len(images) - 2) + [2000]
save_movie(images, 'wave_movie.gif', duration=frame_durations)

# The movie sometimes takes a second before showing up in the file system.
import time
time.sleep(1)

# Download animation.
try:
    from google.colab import files
except ImportError:
    # Not running inside Colab: nothing to download.
    pass
else:
    files.download('wave_movie.gif')