Alex Alemi
# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
# Ask the Colab TPU host to switch to a specific TPU driver build. The request
# is only sent while TPU_DRIVER_MODE is not yet defined in the notebook's
# globals, so re-running this cell does not post again.
if 'TPU_DRIVER_MODE' not in globals():
    url = ''.join([
        'http://',
        os.environ['COLAB_TPU_ADDR'].split(':')[0],  # host part of "host:port"
        ':8475/requestversion/tpu_driver0.1-dev20191206',
    ])
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1
# The following is required to use TPU Driver as JAX's backend.
# Select the "tpu_driver" XLA backend and point it at the gRPC endpoint built
# from the COLAB_TPU_ADDR environment variable.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
# Echo the configured target so the address in use shows up in the cell output.
print(config.FLAGS.jax_backend_target)
import io
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax import vmap, jit, grad, ops, lax, config
from jax import random as jr
# The following is required to use TPU Driver as JAX's backend.
# NOTE(review): these two assignments repeat the ones made right after the
# TPU driver request earlier in this file, here through the `config` name
# imported from `jax`; presumably kept so this cell can run on its own --
# confirm before removing either copy.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from IPython.display import display_png
# Set the padding Matplotlib applies when saving figures to zero.
mpl.rcParams['savefig.pad_inches'] = 0
# Apply the 'seaborn-dark' style sheet to subsequent plots.
# NOTE(review): newer Matplotlib releases may ship the seaborn style sheets
# under a different name -- confirm against the installed version.
plt.style.use('seaborn-dark')
# IPython magic: render Matplotlib figures inline in the notebook output.
%matplotlib inline
The helpers below provide faster, better-antialiased line plotting than the typical matplotlib plotting routines.
@jit
def drawline(im, x0, y0, x1, y1):
    """Draw one antialiased line segment onto `im` and return the result.

    An implementation of Wu's antialiased line algorithm.
    This functional version was adapted from here:
    https://en.wikipedia.org/wiki/Xiaolin_Wu's_line_algorithm

    Args:
      im: image array indexed by two integer indices; for a non-steep line
        the pixel at coordinates (x, y) is `im[x, y]`.
      x0, y0: coordinates of the first endpoint (may be fractional).
      x1, y1: coordinates of the second endpoint (may be fractional).

    Returns:
      The image with the coverage values of the segment added to it.

    NOTE(review): `ops.index_add` and the five-argument
    `lax.cond(pred, operand, fun, operand, fun)` call form come from older
    JAX releases -- confirm they exist in the installed version.
    """
    # Numeric helpers: integer part, round-to-nearest, fractional part and
    # the complement of the fractional part.
    ipart = lambda x: jnp.floor(x).astype('int32')
    round_ = lambda x: ipart(x + 0.5).astype('int32')
    fpart = lambda x: x - jnp.floor(x)
    rfpart = lambda x: 1 - fpart(x)

    def plot(im, x, y, c):
        # Add intensity `c` to the pixel at index [x, y].
        return ops.index_add(im, ops.index[x, y], c)

    # A steep segment changes more in y than in x; it is processed with the
    # roles of x and y exchanged and the indices are swapped back on plotting.
    steep = jnp.abs(y1 - y0) > jnp.abs(x1 - x0)
    # Swap the two entries of a pair when `cond` holds, expressed with
    # lax.cond instead of a Python `if`.
    cond_swap = lambda cond, x: lax.cond(cond, x, lambda x: (x[1], x[0]), x, lambda x: x)
    (x0, y0) = cond_swap(steep, (x0, y0))
    (x1, y1) = cond_swap(steep, (x1, y1))
    # Order the endpoints so that x0 <= x1. The y pair is swapped first
    # because both swaps test `x0 > x1`, which the x swap would change.
    (y0, y1) = cond_swap(x0 > x1, (y0, y1))
    (x0, x1) = cond_swap(x0 > x1, (x0, x1))
    dx = x1 - x0
    dy = y1 - y0
    # Use a gradient of 1.0 when dx is zero instead of the dy/dx result.
    gradient = jnp.where(dx == 0.0, 1.0, dy/dx)
    # handle first endpoint
    xend = round_(x0)
    yend = y0 + gradient * (xend - x0)
    xgap = rfpart(x0 + 0.5)
    xpxl1 = xend # this will be used in main loop
    ypxl1 = ipart(yend)

    # The two branches differ only in index order: the steep branch plots
    # with the indices exchanged. They read the current `yend`/`xgap`, so the
    # lax.cond call below must run before those names are reassigned.
    def true_fun(im):
        im = plot(im, ypxl1, xpxl1, rfpart(yend) * xgap)
        im = plot(im, ypxl1+1, xpxl1, fpart(yend) * xgap)
        return im

    def false_fun(im):
        im = plot(im, xpxl1, ypxl1 , rfpart(yend) * xgap)
        im = plot(im, xpxl1, ypxl1+1, fpart(yend) * xgap)
        return im

    im = lax.cond(steep, im, true_fun, im, false_fun)
    # y value of the line at the first column handled by the main loop.
    intery = yend + gradient
    # handle second endpoint
    xend = round_(x1)
    yend = y1 + gradient * (xend - x1)
    xgap = fpart(x1 + 0.5)
    xpxl2 = xend # this will be used in the main loop
    ypxl2 = ipart(yend)

    def true_fun(im):
        im = plot(im, ypxl2 , xpxl2, rfpart(yend) * xgap)
        im = plot(im, ypxl2+1, xpxl2, fpart(yend) * xgap)
        return im

    def false_fun(im):
        im = plot(im, xpxl2, ypxl2, rfpart(yend) * xgap)
        im = plot(im, xpxl2, ypxl2+1, fpart(yend) * xgap)
        return im

    im = lax.cond(steep, im, true_fun, im, false_fun)

    # Main loop over the columns strictly between the two endpoint pixels
    # (xpxl1+1 up to, but not including, xpxl2). Each column adds to the two
    # pixels around `intery`, weighted by its fractional part.
    def true_fun(arg):
        im, intery = arg

        def body_fun(x, arg):
            im, intery = arg
            im = plot(im, ipart(intery), x, rfpart(intery))
            im = plot(im, ipart(intery)+1, x, fpart(intery))
            intery = intery + gradient
            return (im, intery)

        im, intery = lax.fori_loop(xpxl1+1, xpxl2, body_fun, (im, intery))
        return (im, intery)

    def false_fun(arg):
        im, intery = arg

        def body_fun(x, arg):
            im, intery = arg
            im = plot(im, x, ipart(intery), rfpart(intery))
            im = plot(im, x, ipart(intery)+1, fpart(intery))
            intery = intery + gradient
            return (im, intery)

        im, intery = lax.fori_loop(xpxl1+1, xpxl2, body_fun, (im, intery))
        return (im, intery)

    im, intery = lax.cond(steep, (im, intery), true_fun, (im, intery), false_fun)
    return im
def img_adjust(data, bins=256 * 256):
    """Histogram-equalize `data`, mapping its values into [0, 1].

    A histogram of all values is accumulated into an empirical CDF, and each
    value is replaced by the CDF linearly interpolated at the bin centers.

    Args:
      data: array-like of finite numbers, any shape.
      bins: number of histogram bins used to estimate the CDF. Defaults to
        256 * 256, the value that used to be hard-coded.

    Returns:
      A float NumPy array with the same shape as `data`.
    """
    oim = np.asarray(data)
    if oim.size == 0:
        # Nothing to equalize; this also avoids dividing by a zero total
        # count when normalizing the CDF below.
        return np.zeros(oim.shape, dtype=float)
    flat = oim.ravel()
    hist, bin_edges = np.histogram(flat, bins=bins)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    cdf = hist.cumsum()
    # Normalize so the last entry is exactly 1.0.
    cdf = cdf / float(cdf[-1])
    return np.interp(flat, bin_centers, cdf).reshape(oim.shape)
def imify(arr, vmin=None, vmax=None, cmap=None, origin=None):
    """Convert a 2-D array of values into RGBA bytes via a colormap.

    The values are first passed through `img_adjust`, then mapped with a
    Matplotlib `ScalarMappable`.

    Args:
      arr: array of values to render.
      vmin, vmax: color limits handed to `ScalarMappable.set_clim`.
      cmap: colormap (name or object) for the `ScalarMappable`.
      origin: image origin; when None, the "image.origin" rcParam is used.
        With "lower" the rows are reversed before mapping.

    Returns:
      The result of `ScalarMappable.to_rgba(..., bytes=True)`.
    """
    equalized = img_adjust(arr)

    effective_origin = mpl.rcParams["image.origin"] if origin is None else origin
    if effective_origin == "lower":
        # Flip vertically so the first row ends up at the bottom.
        equalized = equalized[::-1]

    mapper = cm.ScalarMappable(cmap=cmap)
    mapper.set_clim(vmin, vmax)
    return mapper.to_rgba(equalized, bytes=True)
def plot_image(array, **kwargs):
    """Render `array` as a PNG and display it in the notebook.

    Args:
      array: array of values to render; passed to `imify`.
      **kwargs: forwarded unchanged to `imify` (e.g. `cmap`, `vmin`, `vmax`,
        `origin`).
    """
    imarray = imify(array, **kwargs)
    # The context manager releases the in-memory buffer even if encoding the
    # PNG raises.
    with io.BytesIO() as buf:
        plt.imsave(buf, imarray, format="png")
        png_bytes = buf.getvalue()
    display_png(png_bytes, raw=True)
def pack_images(images, rows, cols):
    """Tile a batch of images into a single grid image.

    Args:
      images: array whose last three axes are (width, height, depth); all
        leading axes are flattened into one batch axis.
      rows: maximum number of grid rows; capped at the batch size.
      cols: maximum number of grid columns; capped at `batch // rows`.

    Returns:
      An array of shape (rows * width, cols * height, depth) holding the
      first `rows * cols` images, laid out row by row. An empty batch or
      `rows == 0` yields an empty array of shape (0, 0, depth).
    """
    width, height, depth = np.shape(images)[-3:]
    images = np.reshape(images, (-1, width, height, depth))
    batch = images.shape[0]
    if batch == 0 or rows == 0:
        # No grid to build; return the empty result directly instead of
        # dividing by a zero row count below.
        return np.reshape(images[:0], (0, 0, depth))
    rows = min(rows, batch)
    cols = min(batch // rows, cols)
    images = images[:rows * cols]
    images = np.reshape(images, (rows, cols, width, height, depth))
    # Bring each grid row's width axis next to the row axis so the final
    # reshape concatenates images along both grid directions.
    images = np.transpose(images, [0, 2, 1, 3, 4])
    images = np.reshape(images, [rows * width, cols * height, depth])
    return images
Implement the Lorenz attractor
# Parameters of the Lorenz system used by `f` below.
sigma = 10.
beta = 8./3
rho = 28.


@jit
def f(state, t):
    """Right-hand side of the Lorenz system.

    Args:
      state: length-3 array holding (x, y, z).
      t: time; accepted for the ODE-solver interface but not used.

    Returns:
      Array [dx/dt, dy/dt, dz/dt] evaluated at `state`.
    """
    x, y, z = state
    dx_dt = sigma * (y - x)
    dy_dt = x * (rho - z) - y
    dz_dt = x * y - beta * z
    return jnp.array([dx_dt, dy_dt, dz_dt])
@jit
def rk4(ys, dt, N):
    """Integrate dy/dt = f(y, t) with the classic fourth-order Runge-Kutta rule.

    Args:
      ys: array of states indexed by step along axis 0. `ys[0]` must hold the
        initial condition; rows 1 .. N-1 are filled in.
      dt: step size.
      N: loop bound; steps are taken for i = 1 .. N-1.

    Returns:
      `ys` with rows 1 .. N-1 replaced by the integrated states.
    """
    @jit
    def step(i, ys):
        h = dt
        # Row i is computed from row i-1, which sits at time (i-1)*dt, so
        # that is the time handed to `f`. The Lorenz `f` in this file ignores
        # `t`, so its results are unaffected by this correction.
        t = dt * (i - 1)
        y_prev = ys[i-1]
        k1 = h * f(y_prev, t)
        k2 = h * f(y_prev + k1/2., t + h/2.)
        k3 = h * f(y_prev + k2/2., t + h/2.)
        k4 = h * f(y_prev + k3, t + h)
        ysi = y_prev + 1./6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return ops.index_update(ys, ops.index[i], ysi)
    return lax.fori_loop(1, N, step, ys)
# Number of stored states: the initial condition plus N - 1 integration steps.
N = 40000
# set initial condition
state0 = jnp.array([1., 1., 1.])
# Preallocate the trajectory buffer (one row per stored state) and write the
# initial condition into row 0.
ys = jnp.zeros((N,) + state0.shape)
ys = ops.index_update(ys, ops.index[0], state0)
# solve for N steps
# (step size 0.004; block_until_ready() waits for the result before moving on)
ys = rk4(ys, 0.004, N).block_until_ready()
# plotting size and region:
# xlim / zlim are the (min, max) x and z ranges mapped onto the image, and
# xN / zN are the image dimensions in pixels along x and z.
xlim, zlim = (-20, 20), (0, 50)
xN, zN = 800, 600
# fast, jitted plotting function
@partial(jax.jit, static_argnums=(2, 3, 4, 5))
def jplotter(xs, zs, xlim, zlim, xN, zN):
    """Rasterize the polyline through (xs[i], zs[i]) into an (xN, zN) image.

    Args:
      xs, zs: 1-D arrays of point coordinates along the path.
      xlim, zlim: (min, max) data ranges mapped onto the image axes (static).
      xN, zN: image size in pixels along x and z (static).

    Returns:
      The (xN, zN) image accumulated by drawing each consecutive segment
      with `drawline`.
    """
    # Map data coordinates to (fractional) pixel coordinates.
    x_span = 1.0 * (xlim[1] - xlim[0])
    z_span = 1.0 * (zlim[1] - zlim[0])
    xpixels = (xs - xlim[0]) / x_span * xN
    zpixels = (zs - zlim[0]) / z_span * zN

    def draw_segment(k, canvas):
        # Draw the segment from point k-1 to point k.
        return drawline(canvas, xpixels[k - 1], zpixels[k - 1],
                        xpixels[k], zpixels[k])

    blank = jnp.zeros((xN, zN))
    return lax.fori_loop(1, xpixels.shape[0], draw_segment, blank)
# Rasterize the x-z projection of the trajectory (state components 0 and 2).
im = jplotter(ys[...,0], ys[...,2], xlim, zlim, xN, zN)
# Reverse the z axis and transpose so rows run over z and columns over x,
# then display with the 'magma' colormap.
plot_image(im[:,::-1].T, cmap='magma')
# One replicate is simulated per available device.
N_dev = jax.device_count()
# Number of stored states per replicate.
N = 4000
# set some initial conditions for each replicate
ys = jnp.zeros((N_dev, N, 3))
# Draw one uniform sample per replicate and state component from the
# minval/maxval range, using a fixed PRNG key.
state0 = jr.uniform(jr.PRNGKey(1),
minval=-1., maxval=1.,
shape=(N_dev, 3))
# Scale x and y by 18 and z by 1, then shift z by 10.
state0 = state0 * jnp.array([18,18,1]) + jnp.array((0.,0.,10.))
# Write each replicate's initial condition into its row 0.
ys = ops.index_update(ys, ops.index[:, 0], state0)
# solve each replicate in parallel using `pmap` of rk4 solver:
# dt and N are passed as length-N_dev arrays so every replicate receives one
# entry of each along the mapped axis.
ys = jax.pmap(rk4)(ys,
0.004 * jnp.ones(N_dev),
N * jnp.ones(N_dev, dtype=np.int32)
).block_until_ready()
# parallel plotter using lexical closure and pmap'd core plotting function
def pplotter(_xs, _zs, xlim, zlim, xN, zN):
    """Rasterize one polyline per leading-axis entry of `_xs`/`_zs` in parallel.

    Args:
      _xs, _zs: arrays of point coordinates whose leading axis is mapped
        over by `jax.pmap` (one path per entry).
      xlim, zlim: (min, max) data ranges mapped onto the image axes.
      xN, zN: image size in pixels along x and z.

    Returns:
      An array of shape (_xs.shape[0], xN, zN) with one image per path.
    """
    # The limits and sizes are captured by closure rather than passed
    # through pmap.
    x_span = 1.0 * (xlim[1] - xlim[0])
    z_span = 1.0 * (zlim[1] - zlim[0])

    def rasterize(canvas, xs, zs):
        # Map data coordinates to (fractional) pixel coordinates.
        xpixels = (xs - xlim[0]) / x_span * xN
        zpixels = (zs - zlim[0]) / z_span * zN

        def draw_segment(k, acc):
            # Draw the segment from point k-1 to point k.
            return drawline(acc, xpixels[k - 1], zpixels[k - 1],
                            xpixels[k], zpixels[k])

        return lax.fori_loop(1, xpixels.shape[0], draw_segment, canvas)

    replicas = _xs.shape[0]
    canvases = jnp.zeros((replicas, xN, zN))
    return jax.pmap(rasterize)(canvases, _xs, _zs)
# Data ranges mapped onto the image, and the per-replicate image size in pixels.
xlim, zlim = (-20, 20), (0, 50)
xN, zN = 200, 150
# above, plot ODE traces separately
ims = pplotter(ys[...,0], ys[...,2], xlim, zlim, xN, zN)
# pack_images reads the last three axes as one image, so a trailing axis of
# length 1 is added for tiling (at most 4 rows by 2 columns) and dropped again.
im = pack_images(ims[..., None], 4, 2)[..., 0]
plot_image(im[:,::-1].T, cmap='magma')
# below, plot combined ODE traces
# Rasterize at four times the resolution, then sum the per-replicate images
# into a single image before display.
ims = pplotter(ys[...,0], ys[...,2], xlim, zlim, xN*4, zN*4)
plot_image(jnp.sum(ims, axis=0)[:,::-1].T, cmap='magma')