! pip install -U git+https://github.com/google/jax-cfd.git
Collecting git+https://github.com/google/jax-cfd.git Cloning https://github.com/google/jax-cfd.git to /tmp/pip-req-build-dd9wjtdc Running command git clone -q https://github.com/google/jax-cfd.git /tmp/pip-req-build-dd9wjtdc Requirement already satisfied: jax in /usr/local/lib/python3.7/dist-packages (from jax-cfd==0.1.0) (0.2.21) Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from jax-cfd==0.1.0) (1.19.5) Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jax-cfd==0.1.0) (1.4.1) Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax->jax-cfd==0.1.0) (3.3.0) Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax->jax-cfd==0.1.0) (0.12.0) Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax->jax-cfd==0.1.0) (1.15.0)
import dataclasses
import jax.numpy as jnp
import jax
from jax_cfd.base import funcutils
from jax_cfd.base import grids
from jax_cfd.spectral import equations as spectral_equations
from jax_cfd.spectral import time_stepping
from jax_cfd.spectral import utils as spectral_utils
import numpy as np
import xarray
# Roll out the full Kuramoto–Sivashinsky equation with jax-cfd's spectral
# Crank–Nicolson/RK4 stepper and plot u(t, x).
size = 128      # number of grid points
length = 64.0   # domain length (the FFT-based solver treats it as periodic)
grid = grids.Grid((size,), domain=length)
dt = 20 * grid.step[0] / length
outer_steps = 2000

step_fn = time_stepping.crank_nicolson_rk4(
    spectral_equations.KuramotoSivashinsky(grid, smooth=True), dt)
rollout_fn = funcutils.trajectory(step_fn, outer_steps)

xs, = grid.axes()
# `funcutils.trajectory` (default start_with_input=False) records the state
# *after* each step, so frame i corresponds to time (i + 1) * dt.
ts = dt * np.arange(1, outer_steps + 1)

# NOTE(review): grid spacing is length / size = 0.5, so wavenumber 20 is above
# the grid's Nyquist limit (pi / 0.5 ~= 6.28) and cos(20 x) is not periodic on
# [0, 64); the sampled initial condition is therefore aliased — confirm intended.
u0 = jnp.cos(20 * xs)
u0_hat = jnp.fft.rfft(u0)
_, trajectory_hat = jax.device_get(rollout_fn(u0_hat))
# irfft already returns real values; passing n=size keeps the output length
# equal to len(xs) for odd grid sizes too.
trajectory = jnp.fft.irfft(trajectory_hat, n=size)
xarray.DataArray(
    trajectory, dims=['t', 'x'], coords={'x': xs, 't': ts}
).plot.imshow(x='t', y='x', size=4, aspect=2)
<matplotlib.image.AxesImage at 0x7f06958e7bd0>
@dataclasses.dataclass
class AlmostKS(time_stepping.ImplicitExplicitODE):
    """Kuramoto–Sivashinsky equation with the fourth-order term removed.

    Evolves ``u_t = -1/2 * (u ** 2)_x - u_xx`` on the real-FFT coefficients
    ``uhat``. The nonlinear advection term is treated explicitly and the linear
    ``-u_xx`` term implicitly. Without the ``-u_xxxx`` term, the linear
    multiplier is ``-(2*pi*i*k)**2 == (2*pi*k)**2 >= 0``, so the linear part
    amplifies every nonzero mode.

    Attributes:
        grid: one-dimensional grid; its ``rfft_axes()`` supplies ``kx``.
        smooth: if True, the nonlinear product is formed with
            ``spectral_utils.padded_irfft`` / ``truncated_rfft`` (presumably
            for dealiasing — confirm); otherwise plain ``jnp.fft`` transforms.
    """

    grid: grids.Grid
    smooth: bool = True

    def __post_init__(self):
        # Derived spectral quantities (plain attributes, not dataclass fields).
        (self.kx,) = self.grid.rfft_axes()
        # Multiplier applied to uhat to take d/dx in Fourier space.
        self.two_pi_i_k = 2j * jnp.pi * self.kx
        # Fourier multiplier of the linear operator -d^2/dx^2.
        self.linear_term = -(self.two_pi_i_k ** 2)
        if self.smooth:
            self.rfft = spectral_utils.truncated_rfft
            self.irfft = spectral_utils.padded_irfft
        else:
            self.rfft = jnp.fft.rfft
            self.irfft = jnp.fft.irfft

    def explicit_terms(self, uhat):
        """Return the Fourier coefficients of the nonlinear term `- 1/2 * (u ** 2)_x`."""
        u = self.irfft(uhat)
        u_squared_hat = self.rfft(jnp.square(u))
        return -0.5 * self.two_pi_i_k * u_squared_hat

    def implicit_terms(self, uhat):
        """Return the Fourier coefficients of the linear term `- u_xx`."""
        return self.linear_term * uhat

    def implicit_solve(self, uhat, time_step):
        """Return y solving ``y - time_step * implicit_terms(y) == uhat``.

        The linear operator is diagonal in Fourier space, so the solve is an
        elementwise division by ``1 - time_step * linear_term``.
        """
        denominator = 1 - time_step * self.linear_term
        return (1 / denominator) * uhat
# Roll out AlmostKS (KS without the fourth-order term) from the same initial
# condition and plot u(t, x). NOTE(review): the linear part of AlmostKS
# amplifies every nonzero mode, which is presumably why dt and the number of
# steps are much smaller here than in the full KS run — confirm.
size = 128      # number of grid points
length = 64.0   # domain length (the FFT-based solver treats it as periodic)
grid = grids.Grid((size,), domain=length)
dt = 1 * grid.step[0] / length
outer_steps = 20

step_fn = time_stepping.crank_nicolson_rk4(
    AlmostKS(grid, smooth=True), dt)
rollout_fn = funcutils.trajectory(step_fn, outer_steps)

xs, = grid.axes()
# `funcutils.trajectory` (default start_with_input=False) records the state
# *after* each step, so frame i corresponds to time (i + 1) * dt.
ts = dt * np.arange(1, outer_steps + 1)

# NOTE(review): wavenumber 20 exceeds the grid's Nyquist limit
# (pi / (length / size) ~= 6.28), so this initial condition is aliased — confirm.
u0 = jnp.cos(20 * xs)
u0_hat = jnp.fft.rfft(u0)
_, trajectory_hat = jax.device_get(rollout_fn(u0_hat))
# irfft already returns real values; passing n=size keeps the output length
# equal to len(xs) for odd grid sizes too.
trajectory = jnp.fft.irfft(trajectory_hat, n=size)
xarray.DataArray(
    trajectory, dims=['t', 'x'], coords={'x': xs, 't': ts}
).plot.imshow(x='t', y='x', size=4, aspect=2)
<matplotlib.image.AxesImage at 0x7f0694e80f50>