In this notebook, we illustrate how physics-informed neural networks (PINNs) can be used to solve partial differential equations (PDEs), as proposed in (Raissi et al., 2017 (Part I and Part II)).
This notebook is partially based on another implementation of the PINN approach published on GitHub by pierremtb, as well as on the original code by Maziar Raissi, also available on GitHub.
We describe the PINN approach for approximating the solution $u:[0,T] \times \mathcal{D} \to \mathbb{R}$ of an evolution equation
$$ \begin{align} \partial_t u (t,x) + \mathcal{N}[u](t,x) &= 0, && (t,x) \in (0,T] \times \mathcal{D},\\ u(0,x) &= u_0(x) \quad && x \in \mathcal{D}, \end{align} $$where $\mathcal{N}$ is a nonlinear differential operator acting on $u$, $\mathcal{D} \subset \mathbb{R}^d$ a bounded domain, $T$ denotes the final time and $u_0: \mathcal{D} \to \mathbb{R}$ the prescribed initial data. Although the methodology allows for different types of boundary conditions, we restrict our discussion to the inhomogeneous Dirichlet case and prescribe
$$ \begin{align} \hspace{7em} u(t,x) &= u_b(t,x) && \quad (t,x) \in (0,T] \times \partial \mathcal{D}, \end{align} $$where $\partial \mathcal{D}$ denotes the boundary of the domain $\mathcal{D}$ and $u_b: (0,T] \times \partial \mathcal{D} \to \mathbb{R}$ the given boundary data.
The method constructs a neural network approximation
$$ u_\theta(t,x) \approx u(t,x) $$of the solution of the nonlinear PDE, where $u_\theta :[0,T] \times \mathcal{D} \to \mathbb{R}$ denotes a function realized by a neural network with parameters $\theta$.
The continuous time approach for the parabolic PDE as described in (Raissi et al., 2017 (Part I)) is based on the (strong) residual of a given neural network approximation $u_\theta \colon [0,T] \times \mathcal{D} \to \mathbb{R} $ of the solution $u$, i.e.,
$$ \begin{align} r_\theta (t,x) := \partial_t u_\theta (t,x) + \mathcal{N}[u_\theta] (t,x). \end{align} $$To incorporate this PDE residual $r_\theta$ into a loss function to be minimized, PINNs require a further differentiation to evaluate the differential operators $\partial_t u_\theta$ and $\mathcal{N}[u_\theta]$. Thus the PINN term $r_\theta$ shares the same parameters as the original network $u_\theta(t,x)$, but respects the underlying "physics" of the nonlinear PDE. Both types of derivatives can be easily determined through automatic differentiation with current state-of-the-art machine learning libraries, e.g., TensorFlow or PyTorch.
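As a minimal, self-contained illustration (not part of the original code), the following sketch shows how `tf.GradientTape` delivers such derivatives for a simple closed-form function; $f(t,x) = \sin(\pi x)\,e^{-t}$ is chosen purely for demonstration.
import numpy as np
import tensorflow as tf

# Toy example: derivatives of f(t,x) = sin(pi*x) * exp(-t) via automatic differentiation
t = tf.Variable([[0.5]])
x = tf.Variable([[0.3]])
with tf.GradientTape(persistent=True) as tape:
    f = tf.sin(np.pi * x) * tf.exp(-t)
f_t = tape.gradient(f, t)   # equals -sin(pi*x) * exp(-t)
f_x = tape.gradient(f, x)   # equals pi * cos(pi*x) * exp(-t)
del tape
print(f_t.numpy(), f_x.numpy())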
The PINN approach for the solution of the initial and boundary value problem now proceeds by minimization of the loss functional
$$ \begin{align} \phi_\theta(X) := \phi_\theta^r(X^r) + \phi_\theta^0(X^0) + \phi_\theta^b(X^b), \end{align} $$where $X$ denotes the collection of training data and the loss function $\phi_\theta$ contains the following terms:
the mean squared residual
$$ \begin{align} \phi_\theta^r(X^r) := \frac{1}{N_r}\sum_{i=1}^{N_r} \left|r_\theta(t_i^r, x_i^r)\right|^2 \end{align} $$in a number of collocation points $X^r:=\{(t_i^r, x_i^r)\}_{i=1}^{N_r} \subset (0,T] \times \mathcal{D}$, where $r_\theta$ is the physics-informed neural network, as well as the mean squared misfit with respect to the initial and boundary conditions
$$ \begin{align} \phi_\theta^0(X^0) := \frac{1}{N_0}\sum_{i=1}^{N_0} \left|u_\theta(t_i^0, x_i^0) - u_0(x_i^0)\right|^2 \quad \text{and} \quad \phi_\theta^b(X^b) := \frac{1}{N_b}\sum_{i=1}^{N_b} \left|u_\theta(t_i^b, x_i^b) - u_b(t_i^b, x_i^b)\right|^2 \end{align} $$in a number of points $X^0:=\{(t^0_i,x^0_i)\}_{i=1}^{N_0} \subset \{0\} \times \mathcal{D}$ and $X^b:=\{(t^b_i,x^b_i)\}_{i=1}^{N_b} \subset (0,T] \times \partial \mathcal{D}$, where $u_\theta$ is the neural network approximation of the solution $u\colon[0,T] \times \mathcal{D} \to \mathbb{R}$.
Note that the training data $X$ consists entirely of time-space coordinates.
To illustrate the PINN approach we consider the one-dimensional Burgers equation on the spatial domain $\mathcal{D} = [-1,1]$
$$ \begin{aligned} \partial_t u + u \, \partial_x u - (0.01/\pi) \, \partial_{xx} u &= 0, \quad &&\quad (t,x) \in (0,1] \times (-1,1),\\ u(0,x) &= - \sin(\pi \, x), \quad &&\quad x \in [-1,1],\\ u(t,-1) = u(t,1) &= 0, \quad &&\quad t \in (0,1]. \end{aligned} $$This PDE arises in various disciplines such as traffic flow, fluid mechanics and gas dynamics, and can be derived from the Navier-Stokes equations, see (Basdevant et al., 1986).
This code runs with TensorFlow version 2.3.0.
The implementation relies mainly on the scientific computing library NumPy and the machine learning library TensorFlow.
All computations were performed on an Intel i7 CPU (8th Gen) with 16 GByte DDR3 RAM (2133 MHz) within a couple of minutes.
%config InlineBackend.figure_format = 'svg'
# Import TensorFlow and NumPy
import tensorflow as tf
import numpy as np
# Set data type
DTYPE='float32'
tf.keras.backend.set_floatx(DTYPE)
# Set constants
pi = tf.constant(np.pi, dtype=DTYPE)
viscosity = .01/pi
# Define initial condition
def fun_u_0(x):
return -tf.sin(pi * x)
# Define boundary condition
def fun_u_b(t, x):
n = x.shape[0]
return tf.zeros((n,1), dtype=DTYPE)
# Define residual of the PDE
def fun_r(t, x, u, u_t, u_x, u_xx):
return u_t + u * u_x - viscosity * u_xx
We assume that the collocation points $X_r$, as well as the points for the initial time and boundary data, $X_0$ and $X_b$, are generated by random sampling from a uniform distribution. Although uniformly distributed data are sufficient in our experiments, the authors of (Raissi et al., 2017 (Part I)) employed a space-filling Latin hypercube sampling strategy (Stein, 1987). Our numerical experiments indicate that this strategy slightly improves the observed convergence rate, but for simplicity this notebook employs uniform sampling throughout; a Latin hypercube variant is sketched after the sampling code below. We choose training data of size $N_0 = N_b = 50$ and $N_r = 10000$.
# Set number of data points
N_0 = 50
N_b = 50
N_r = 10000
# Set boundary
tmin = 0.
tmax = 1.
xmin = -1.
xmax = 1.
# Lower bounds
lb = tf.constant([tmin, xmin], dtype=DTYPE)
# Upper bounds
ub = tf.constant([tmax, xmax], dtype=DTYPE)
# Set random seed for reproducible results
tf.random.set_seed(0)
# Draw uniform sample points for initial boundary data
t_0 = tf.ones((N_0,1), dtype=DTYPE)*lb[0]
x_0 = tf.random.uniform((N_0,1), lb[1], ub[1], dtype=DTYPE)
X_0 = tf.concat([t_0, x_0], axis=1)
# Evaluate initial condition at x_0
u_0 = fun_u_0(x_0)
# Boundary data
t_b = tf.random.uniform((N_b,1), lb[0], ub[0], dtype=DTYPE)
x_b = lb[1] + (ub[1] - lb[1]) * tf.keras.backend.random_bernoulli((N_b,1), 0.5, dtype=DTYPE)
X_b = tf.concat([t_b, x_b], axis=1)
# Evaluate boundary condition at (t_b,x_b)
u_b = fun_u_b(t_b, x_b)
# Draw uniformly sampled collocation points
t_r = tf.random.uniform((N_r,1), lb[0], ub[0], dtype=DTYPE)
x_r = tf.random.uniform((N_r,1), lb[1], ub[1], dtype=DTYPE)
X_r = tf.concat([t_r, x_r], axis=1)
# Collect boundary and initial data in lists
X_data = [X_0, X_b]
u_data = [u_0, u_b]
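As an aside, and not used in the remainder of this notebook, a Latin hypercube variant of the collocation-point sampling could look as follows. This sketch assumes SciPy >= 1.7, which provides `scipy.stats.qmc`.
# Optional sketch: Latin hypercube sampling of the collocation points
# (space-filling alternative to uniform sampling, cf. (Raissi et al., 2017 (Part I)))
from scipy.stats import qmc

sampler = qmc.LatinHypercube(d=2, seed=0)
sample = sampler.random(n=N_r)                        # samples in [0,1]^2
X_r_lhs = qmc.scale(sample, lb.numpy(), ub.numpy())   # rescale to [lb, ub]
X_r_lhs = tf.constant(X_r_lhs, dtype=DTYPE)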
Next, we illustrate the collocation points (red dots) and the positions where the boundary and initial conditions will be enforced (cross marks, color indicates value).
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(9,6))
plt.scatter(t_0, x_0, c=u_0, marker='X', vmin=-1, vmax=1)
plt.scatter(t_b, x_b, c=u_b, marker='X', vmin=-1, vmax=1)
plt.scatter(t_r, x_r, c='r', marker='.', alpha=0.1)
plt.xlabel('$t$')
plt.ylabel('$x$')
plt.title('Positions of collocation points and boundary data');
#plt.savefig('Xdata_Burgers.pdf', bbox_inches='tight', dpi=300)
In this example, adapted from (Raissi et al., 2017 (Part I)), we assume a feedforward neural network of the following structure:

- the input is scaled elementwise to lie in the interval $[-1,1]$,
- 8 fully connected hidden layers, each with 20 neurons and a hyperbolic tangent activation function,
- one fully connected output layer with a single neuron.
This setting results in a network with $3021$ trainable parameters (first hidden layer: $2 \cdot 20 + 20 = 60$; seven intermediate layers: each $20 \cdot 20 + 20 = 420$; output layer: $20 \cdot 1 + 1 = 21$).
def init_model(num_hidden_layers=8, num_neurons_per_layer=20):
# Initialize a feedforward neural network
model = tf.keras.Sequential()
# Input is two-dimensional (time + one spatial dimension)
model.add(tf.keras.Input(2))
# Introduce a scaling layer to map input to [lb, ub]
scaling_layer = tf.keras.layers.Lambda(
lambda x: 2.0*(x - lb)/(ub - lb) - 1.0)
model.add(scaling_layer)
# Append hidden layers
for _ in range(num_hidden_layers):
model.add(tf.keras.layers.Dense(num_neurons_per_layer,
activation=tf.keras.activations.get('tanh'),
kernel_initializer='glorot_normal'))
# Output is one-dimensional
model.add(tf.keras.layers.Dense(1))
return model
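To confirm the parameter count stated above, one can instantiate the network and query `count_params()` (a quick check, not part of the original notebook):
# Quick check of the number of trainable parameters (expected: 3021)
check_model = init_model()
print(check_model.count_params())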
In the following code cell, we define a function which evaluates the residual
$$ \begin{align} r_\theta (t,x) := \partial_t u_\theta (t,x) + \mathcal{N}[u_\theta] (t,x). \end{align} $$of the nonlinear PDE in the points $X_r = \{(t^r_i,x^r_i)\}_{i=1}^{N_r}$. To compute the necessary partial derivatives we use the automatic differentiation capabilities of TensorFlow.
For the Burgers equation, this entails computing $\partial_t u_\theta$, $\partial_x u_\theta$ and $\partial_{xx} u_\theta$.
In TensorFlow, this is done via a `GradientTape` (see also the documentation), which keeps track of the watched variables, in our case `t` and `x`, in order to compute the derivatives.
def get_r(model, X_r):
# A tf.GradientTape is used to compute derivatives in TensorFlow
with tf.GradientTape(persistent=True) as tape:
# Split t and x to compute partial derivatives
t, x = X_r[:, 0:1], X_r[:,1:2]
# Variables t and x are watched during tape
# to compute derivatives u_t and u_x
tape.watch(t)
tape.watch(x)
# Determine residual
u = model(tf.stack([t[:,0], x[:,0]], axis=1))
# Compute gradient u_x within the GradientTape
# since we need second derivatives
u_x = tape.gradient(u, x)
u_t = tape.gradient(u, t)
u_xx = tape.gradient(u_x, x)
del tape
return fun_r(t, x, u, u_t, u_x, u_xx)
The next function computes the loss of our model,
$$ \begin{align} \phi_\theta(X) := \phi_\theta^r(X^r) + \phi_\theta^0(X^0) + \phi_\theta^b(X^b), \end{align} $$as a function of the training data. The collocation points are given by `X_r`, while the initial and boundary data are contained in `X_data = [X_0, X_b]` and `u_data = [u_0, u_b]`.
def compute_loss(model, X_r, X_data, u_data):
# Compute phi^r
r = get_r(model, X_r)
phi_r = tf.reduce_mean(tf.square(r))
# Initialize loss
loss = phi_r
# Add phi^0 and phi^b to the loss
for i in range(len(X_data)):
u_pred = model(X_data[i])
loss += tf.reduce_mean(tf.square(u_data[i] - u_pred))
return loss
The next function computes the gradient of the loss function $\phi_\theta$ with respect to the unknown variables in the model, i.e., $\nabla_\theta \phi_\theta$; these variables are called `trainable_variables` in TensorFlow. This is again done via a `GradientTape`, but this time it keeps track of the parameters $\theta$ of our model, which can be accessed by `model.trainable_variables`.
def get_grad(model, X_r, X_data, u_data):
with tf.GradientTape(persistent=True) as tape:
# This tape is for derivatives with
# respect to trainable variables
tape.watch(model.trainable_variables)
loss = compute_loss(model, X_r, X_data, u_data)
g = tape.gradient(loss, model.trainable_variables)
del tape
return loss, g
Next, we initialize the model, set the learning rate to the step function
$$ \delta(n) = 0.01 \, \textbf{1}_{\{n < 1000\}} + 0.001 \, \textbf{1}_{\{1000 \le n < 3000\}} + 0.0005 \, \textbf{1}_{\{3000 \le n\}}, $$which decays in a piecewise constant fashion, and set up a `tf.keras.optimizers.Adam` optimizer to train the model.
# Initialize model aka u_\theta
model = init_model()
model.summary()
# We choose a piecewise decay of the learning rate, i.e., the
# step size in the gradient descent type algorithm
# the first 1000 steps use a learning rate of 0.01
# from 1000 - 3000: learning rate = 0.001
# from 3000 onwards: learning rate = 0.0005
lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay([1000,3000],[1e-2,1e-3,5e-4])
# Choose the optimizer
optim = tf.keras.optimizers.Adam(learning_rate=lr)
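As a quick sanity check (not part of the original notebook), the schedule can be evaluated at a few representative step counts:
# Inspect the learning rate schedule at selected optimization steps
for step in [0, 500, 1500, 2500, 3500, 4500]:
    print('step {:5d}: lr = {:.4f}'.format(step, lr(step).numpy()))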
Train the model for $N=5000$ epochs (takes approximately 3 minutes). Here, we set up a function `train_step()` which performs one training step.

Note: `@tf.function` is a so-called decorator in Python. This particular decorator traces the function that follows, in our case `train_step`, into a TensorFlow graph, which may speed up the training significantly.
from time import time
# Define one training step as a TensorFlow function to increase speed of training
@tf.function
def train_step():
# Compute current loss and gradient w.r.t. parameters
loss, grad_theta = get_grad(model, X_r, X_data, u_data)
# Perform gradient descent step
optim.apply_gradients(zip(grad_theta, model.trainable_variables))
return loss
# Number of training epochs
N = 5000
hist = []
# Start timer
t0 = time()
for i in range(N+1):
loss = train_step()
# Append current loss to hist
hist.append(loss.numpy())
# Output current loss every 500 iterations
if i%500 == 0:
print('It {:05d}: loss = {:10.8e}'.format(i,loss))
# Print computation time
print('\nComputation time: {} seconds'.format(time()-t0))
from mpl_toolkits.mplot3d import Axes3D
# Set up meshgrid
N = 600
tspace = np.linspace(lb[0], ub[0], N + 1)
xspace = np.linspace(lb[1], ub[1], N + 1)
T, X = np.meshgrid(tspace, xspace)
Xgrid = np.vstack([T.flatten(),X.flatten()]).T
# Determine predictions of u(t, x)
upred = model(tf.cast(Xgrid,DTYPE))
# Reshape upred
U = upred.numpy().reshape(N+1,N+1)
# Surface plot of solution u(t,x)
fig = plt.figure(figsize=(9,6))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(T, X, U, cmap='viridis');
ax.view_init(35,35)
ax.set_xlabel('$t$')
ax.set_ylabel('$x$')
ax.set_zlabel('$u_\\theta(t,x)$')
ax.set_title('Solution of Burgers equation');
#plt.savefig('Burgers_Solution.pdf', bbox_inches='tight', dpi=300);
plt.plot(X[:,0], U[:,0], label='Initial condition')
plt.plot(X[:,0], U[:,-1], label='Final solution')
plt.xlabel('x'); plt.ylabel('u')
plt.legend();
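As an additional check (not in the original notebook), the learned solution at $t=0$ can be compared with the prescribed initial condition $u_0(x) = -\sin(\pi x)$:
# Compare the network prediction at t = 0 with the exact initial condition
u0_exact = fun_u_0(tf.constant(xspace[:, None], dtype=DTYPE)).numpy().flatten()
print('Max error at t=0: {:6.4e}'.format(np.max(np.abs(U[:, 0] - u0_exact))))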
fig = plt.figure(figsize=(9,6))
ax = fig.add_subplot(111)
ax.semilogy(range(len(hist)), hist,'k-')
ax.set_xlabel('$n_{epoch}$')
ax.set_ylabel('$\\phi_{n_{epoch}}$');
In this section, we implement PINNs as a class which can be used for further testing. Here, we derive the class `PINN_NeuralNet` from `tf.keras.Model`. Required arguments are the lower bound `lb` and the upper bound `ub`.
# Define model architecture
class PINN_NeuralNet(tf.keras.Model):
""" Set basic architecture of the PINN model."""
def __init__(self, lb, ub,
output_dim=1,
num_hidden_layers=8,
num_neurons_per_layer=20,
activation='tanh',
kernel_initializer='glorot_normal',
**kwargs):
super().__init__(**kwargs)
self.num_hidden_layers = num_hidden_layers
self.output_dim = output_dim
self.lb = lb
self.ub = ub
# Define NN architecture
self.scale = tf.keras.layers.Lambda(
lambda x: 2.0*(x - lb)/(ub - lb) - 1.0)
self.hidden = [tf.keras.layers.Dense(num_neurons_per_layer,
activation=tf.keras.activations.get(activation),
kernel_initializer=kernel_initializer)
for _ in range(self.num_hidden_layers)]
self.out = tf.keras.layers.Dense(output_dim)
def call(self, X):
"""Forward-pass through neural network."""
Z = self.scale(X)
for i in range(self.num_hidden_layers):
Z = self.hidden[i](Z)
return self.out(Z)
Next, we derive a class `PINNSolver` which can be used as a base class. It possesses two methods to solve the PDE:

- `solve_with_TFoptimizer` uses a TensorFlow optimizer object as input, e.g., the Adam optimizer set up above;
- `solve_with_ScipyOptimizer` resembles the L-BFGS method proposed in the original paper, using the L-BFGS implementation provided by SciPy.

import scipy.optimize
class PINNSolver():
def __init__(self, model, X_r):
self.model = model
# Store collocation points
self.t = X_r[:,0:1]
self.x = X_r[:,1:2]
# Initialize history of losses and global iteration counter
self.hist = []
self.iter = 0
def get_r(self):
with tf.GradientTape(persistent=True) as tape:
# Watch variables representing t and x during this GradientTape
tape.watch(self.t)
tape.watch(self.x)
# Compute current values u(t,x)
u = self.model(tf.stack([self.t[:,0], self.x[:,0]], axis=1))
u_x = tape.gradient(u, self.x)
u_t = tape.gradient(u, self.t)
u_xx = tape.gradient(u_x, self.x)
del tape
return self.fun_r(self.t, self.x, u, u_t, u_x, u_xx)
def loss_fn(self, X, u):
# Compute phi_r
r = self.get_r()
phi_r = tf.reduce_mean(tf.square(r))
# Initialize loss
loss = phi_r
# Add phi_0 and phi_b to the loss
for i in range(len(X)):
u_pred = self.model(X[i])
loss += tf.reduce_mean(tf.square(u[i] - u_pred))
return loss
def get_grad(self, X, u):
with tf.GradientTape(persistent=True) as tape:
# This tape is for derivatives with
# respect to trainable variables
tape.watch(self.model.trainable_variables)
loss = self.loss_fn(X, u)
g = tape.gradient(loss, self.model.trainable_variables)
del tape
return loss, g
def fun_r(self, t, x, u, u_t, u_x, u_xx):
"""Residual of the PDE"""
return u_t + u * u_x - viscosity * u_xx
def solve_with_TFoptimizer(self, optimizer, X, u, N=1001):
"""This method performs a gradient descent type optimization."""
@tf.function
def train_step():
loss, grad_theta = self.get_grad(X, u)
# Perform gradient descent step
optimizer.apply_gradients(zip(grad_theta, self.model.trainable_variables))
return loss
for i in range(N):
loss = train_step()
self.current_loss = loss.numpy()
self.callback()
def solve_with_ScipyOptimizer(self, X, u, method='L-BFGS-B', **kwargs):
"""This method provides an interface to solve the learning problem
using a routine from scipy.optimize.minimize.
(TensorFlow 1.x provided such an interface, which is no longer
supported in TensorFlow 2.x.)
Type conversion is necessary since the SciPy routines are written in Fortran
and require 64-bit floats instead of 32-bit floats."""
def get_weight_tensor():
"""Function to return current variables of the model
as 1d tensor as well as corresponding shapes as lists."""
weight_list = []
shape_list = []
# Loop over all variables, i.e. weight matrices, bias vectors and unknown parameters
for v in self.model.variables:
shape_list.append(v.shape)
weight_list.extend(v.numpy().flatten())
weight_list = tf.convert_to_tensor(weight_list)
return weight_list, shape_list
x0, shape_list = get_weight_tensor()
def set_weight_tensor(weight_list):
"""Function which sets list of weights
to variables in the model."""
idx = 0
for v in self.model.variables:
vs = v.shape
# Weight matrices
if len(vs) == 2:
sw = vs[0]*vs[1]
new_val = tf.reshape(weight_list[idx:idx+sw],(vs[0],vs[1]))
idx += sw
# Bias vectors
elif len(vs) == 1:
new_val = weight_list[idx:idx+vs[0]]
idx += vs[0]
# Variables (in case of parameter identification setting)
elif len(vs) == 0:
new_val = weight_list[idx]
idx += 1
# Assign variables (Casting necessary since scipy requires float64 type)
v.assign(tf.cast(new_val, DTYPE))
def get_loss_and_grad(w):
"""Function that provides current loss and gradient
w.r.t the trainable variables as vector. This is mandatory
for the LBFGS minimizer from scipy."""
# Update weights in model
set_weight_tensor(w)
# Determine value of \phi and gradient w.r.t. \theta at w
loss, grad = self.get_grad(X, u)
# Store current loss for callback function
loss = loss.numpy().astype(np.float64)
self.current_loss = loss
# Flatten gradient
grad_flat = []
for g in grad:
grad_flat.extend(g.numpy().flatten())
# Gradient list to array
grad_flat = np.array(grad_flat,dtype=np.float64)
# Return value and gradient of \phi as tuple
return loss, grad_flat
return scipy.optimize.minimize(fun=get_loss_and_grad,
x0=x0,
jac=True,
method=method,
callback=self.callback,
**kwargs)
def callback(self, xr=None):
if self.iter % 50 == 0:
print('It {:05d}: loss = {:10.8e}'.format(self.iter,self.current_loss))
self.hist.append(self.current_loss)
self.iter+=1
def plot_solution(self, **kwargs):
N = 600
tspace = np.linspace(self.model.lb[0], self.model.ub[0], N+1)
xspace = np.linspace(self.model.lb[1], self.model.ub[1], N+1)
T, X = np.meshgrid(tspace, xspace)
Xgrid = np.vstack([T.flatten(),X.flatten()]).T
upred = self.model(tf.cast(Xgrid,DTYPE))
U = upred.numpy().reshape(N+1,N+1)
fig = plt.figure(figsize=(9,6))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(T, X, U, cmap='viridis', **kwargs)
ax.set_xlabel('$t$')
ax.set_ylabel('$x$')
ax.set_zlabel('$u_\\theta(t,x)$')
ax.view_init(35,35)
return ax
def plot_loss_history(self, ax=None):
if not ax:
fig = plt.figure(figsize=(7,5))
ax = fig.add_subplot(111)
ax.semilogy(range(len(self.hist)), self.hist,'k-')
ax.set_xlabel('$n_{epoch}$')
ax.set_ylabel('$\\phi^{n_{epoch}}$')
return ax
The following code cell shows how the new classes `PINN_NeuralNet` and `PINNSolver` can be used to solve the Burgers equation, this time using the SciPy implementation of L-BFGS (takes around 3 minutes).
# Initialize model
model = PINN_NeuralNet(lb, ub)
model.build(input_shape=(None,2))
# Initialize PINN solver
solver = PINNSolver(model, X_r)
# Decide which optimizer should be used
#mode = 'TFoptimizer'
mode = 'ScipyOptimizer'
# Start timer
t0 = time()
if mode == 'TFoptimizer':
# Choose optimizer
lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay([1000,3000],[1e-2,1e-3,5e-4])
optim = tf.keras.optimizers.Adam(learning_rate=lr)
solver.solve_with_TFoptimizer(optim, X_data, u_data, N=4001)
elif mode == 'ScipyOptimizer':
solver.solve_with_ScipyOptimizer(X_data, u_data,
method='L-BFGS-B',
options={'maxiter': 50000,
'maxfun': 50000,
'maxcor': 50,
'maxls': 50,
'ftol': 1.0*np.finfo(float).eps})
# Print computation time
print('\nComputation time: {} seconds'.format(time()-t0))
Plot solution and loss history.
solver.plot_solution();
solver.plot_loss_history();
As a second example we consider the one-dimensional Eikonal equation backward in time on the domain $\mathcal{D}=[-1,1]$
$$ \begin{align} -\partial_t u(t,x) + |\nabla u|(t,x) &= 1, \quad & &(t,x) \in [0,T) \times [-1,1],\\ u(T,x) &= 0, \quad & &x \in [-1,1],\\ u(t,-1) = u(t, 1) &= 0, \quad & & t \in [0,T). \end{align} $$Note that the partial differential equation can be equally written as a Hamilton-Jacobi-Bellman equation, viz
$$ -\partial_t u(t,x) + \sup_{|c| \le 1} \{c \, \nabla u(t,x)\} = 1 \quad (t,x) \in [0,T) \times [-1,1], $$which characterizes the solution of an optimal control problem seeking to minimize the distance from a point $(t,x)$ to the boundary $[0,T] \times \partial \mathcal{D} \cup \{T\} \times \mathcal{D}$. As is easily verified, the solution for $T = 1$ (the value used below) is given by $u(t,x) = \min\{ 1 - t, 1 - |x| \}$. The fact that the Eikonal equation runs backward in time is in accordance with its interpretation as the optimality condition of a control problem. Note that this equation can be transformed into a forward evolution problem by the change of variables $\hat t = T - t$.
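For completeness (an added remark, not in the original text), the verification splits into two cases:
$$ \begin{aligned} \text{for } |x| < t: \quad u = 1 - t, &\quad -\partial_t u + |\nabla u| = 1 + 0 = 1,\\ \text{for } |x| > t: \quad u = 1 - |x|, &\quad -\partial_t u + |\nabla u| = 0 + 1 = 1, \end{aligned} $$while the terminal and boundary conditions follow from $u(1,x) = \min\{0, 1-|x|\} = 0$ and $u(t,\pm 1) = \min\{1-t, 0\} = 0$.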
N_0 = 50
N_b = 50
N_r = 10000
# Specify boundaries
lb = tf.constant([0., -1.], dtype=DTYPE)
ub = tf.constant([1., 1.], dtype=DTYPE)
def Eikonal_u_0(x):
n = x.shape[0]
return tf.zeros((n,1), dtype=DTYPE)
def Eikonal_u_b(t, x):
n = x.shape[0]
return tf.zeros((n,1), dtype=DTYPE)
This code snippet is almost identical to the code from above. We choose $N_0 = 50$ uniformly distributed points for the final time condition, $N_b = 50$ boundary points, and sample $N_r = 10000$ collocation points uniformly within the domain. Below, we derive a new solver with `PINNSolver` as base class.
tf.random.set_seed(0)
# Final time data
t_0 = tf.ones((N_0,1), dtype=DTYPE) * ub[0]
x_0 = tf.random.uniform((N_0,1), lb[1], ub[1], dtype=DTYPE)
X_0 = tf.concat([t_0, x_0], axis=1)
u_0 = Eikonal_u_0(x_0)
# Boundary data
t_b = tf.random.uniform((N_b,1), lb[0], ub[0], dtype=DTYPE)
x_b = lb[1] + (ub[1] - lb[1]) * tf.keras.backend.random_bernoulli((N_b,1), 0.5, dtype=DTYPE)
X_b = tf.concat([t_b, x_b], axis=1)
u_b = Eikonal_u_b(t_b, x_b)
# Collocation points
t_r = tf.random.uniform((N_r,1), lb[0], ub[0], dtype=DTYPE)
x_r = tf.random.uniform((N_r,1), lb[1], ub[1], dtype=DTYPE)
X_r = tf.concat([t_r, x_r], axis=1)
# Collect boundary and initial data in lists
X_data = [X_0,X_b]
u_data = [u_0,u_b]
Now, we derive a solver for the Eikonal equation from the `PINNSolver` class. Since the Eikonal equation does not involve second-order derivatives, we implement a new method `get_r` which avoids their computation.
class EikonalPINNSolver(PINNSolver):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def fun_r(self, t, x, u, u_t, u_x, u_xx):
"""Residual of the PDE"""
return -u_t + tf.abs(u_x) - 1.
def get_r(self):
"""We update get_r since the Eikonal equation is a first-order equation.
Therefore, it is not necessary to compute second derivatives."""
with tf.GradientTape(persistent=True) as tape:
# Watch variables representing t and x during this GradientTape
tape.watch(self.t)
tape.watch(self.x)
# Compute current values u(t,x)
u = self.model(tf.stack([self.t[:,0], self.x[:,0]], axis=1))
u_x = tape.gradient(u, self.x)
u_t = tape.gradient(u, self.t)
del tape
return self.fun_r(self.t, self.x, u, u_t, u_x, None)
The neural network model chosen for this particular problem can be simpler. We decided to use only two hidden layers with 20 neurons each, resulting in $501$ unknown parameters (first hidden layer: $2 \cdot 20 + 20 = 60$; one intermediate layer: $20 \cdot 20 + 20 = 420$; output layer: $20 \cdot 1 + 1 = 21$). To account for the lack of smoothness of the solution, we choose a non-differentiable activation function, although the hyperbolic tangent function seems to be able to approximate the kinks in the solution sufficiently well. Here, we decided to use the *leaky rectified linear unit (leaky ReLU)* activation function
$$ \begin{align*} \sigma(z) = \begin{cases} z &\text{ if } z \ge 0,\\ 0.1 \, z &\text{ otherwise}, \end{cases} \end{align*} $$which displays a non-vanishing gradient when the unit is not active, i.e., when $z < 0$.
# Initialize model
model = PINN_NeuralNet(lb, ub, num_hidden_layers=2,
activation=tf.keras.layers.LeakyReLU(alpha=0.1),
kernel_initializer='he_normal')
model.build(input_shape=(None,2))
# Initialize PINN solver
eikonalSolver = EikonalPINNSolver(model, X_r)
Start training (takes approximately 40 seconds).
# Choose step sizes aka learning rate
lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay([3000,7000],[1e-1,1e-2,1e-3])
# Solve with Adam optimizer
optim = tf.keras.optimizers.Adam(learning_rate=lr)
# Start timer
t0 = time()
eikonalSolver.solve_with_TFoptimizer(optim, X_data, u_data, N=10001)
# Print computation time
print('\nComputation time: {} seconds'.format(time()-t0))
Plot the results.
eikonalSolver.plot_solution();
#plt.savefig('Eikonal_Solution.pdf', bbox_inches='tight', dpi=300)
eikonalSolver.plot_loss_history();
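Since the exact solution is known in closed form, we can also quantify the accuracy of the trained network on random test points (an additional check, not part of the original notebook):
# Compare the trained network with the exact solution u(t,x) = min(1-t, 1-|x|)
t_test = tf.random.uniform((1000,1), lb[0], ub[0], dtype=DTYPE)
x_test = tf.random.uniform((1000,1), lb[1], ub[1], dtype=DTYPE)
u_true = tf.math.minimum(1. - t_test, 1. - tf.abs(x_test))
u_pred = model(tf.concat([t_test, x_test], axis=1))
rel_err = tf.norm(u_pred - u_true) / tf.norm(u_true)
print('Relative L2 error: {:6.4e}'.format(rel_err.numpy()))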
In this section, we want to demonstrate how the PINN approach can be used to solve partial differential equations with an unknown parameter $\lambda$. To be more precise, we consider the parametric Eikonal equation
$$ \begin{aligned} - \partial_t u + \sup_{|c|\le1} c \cdot \nabla u = -\partial_t u + |\nabla u| &= \lambda^{-1}\\ u(T,x) &= 0\\ u(t,-1) = u(t,1) &= 0 \end{aligned} $$with unknown parameter $\lambda$. The explicit solution is $u^*(t,x) = \lambda^{-1} \min\{1-t, 1-|x|\}$.
lambd_star = 3.
def u_expl(t, x, lambd_star):
"""Explicit solution of the parametric Eikonal equation."""
y = 1./lambd_star
return y * tf.math.minimum(1-t, 1-tf.abs(x))
Next, we draw $N_d = 500$ uniformly distributed measurements of the exact solution.
N_d = 500
noise = 0.0
# Draw points with measurements randomly
t_d = tf.random.uniform((N_d,1), lb[0], ub[0], dtype=DTYPE)
x_d = tf.random.uniform((N_d,1), lb[1], ub[1], dtype=DTYPE)
X_d = tf.concat([t_d, x_d], axis=1)
u_d = u_expl(t_d, x_d, lambd_star)
u_d += noise * tf.random.normal(u_d.shape, dtype=DTYPE)
# Copy the original data lists so that X_data and u_data remain unchanged
X_param = X_data.copy()
u_param = u_data.copy()
Since both the boundary and initial time data are of Dirichlet type, we may handle the measured data exactly like $X_0$ and $X_b$. Thus, we can simply append $X_d$ and $u_d$ to `X_param` and `u_param`.
Note that the approach illustrated here is slightly different from the one introduced in (Raissi et al., 2017 (Part II)) which takes only measurement data into account.
X_param.append(X_d)
u_param.append(u_d)
Next, we derive a new network class which takes the additional parameter $\lambda$ into account. Note that this parameter has to be part of the model in order to be learnt during training.
class PINNIdentificationNet(PINN_NeuralNet):
def __init__(self, *args, **kwargs):
# Call init of base class
super().__init__(*args,**kwargs)
# Initialize variable for lambda
self.lambd = tf.Variable(1.0, trainable=True, dtype=DTYPE)
self.lambd_list = []
Now, we derive a new solver class which only updates the evaluation of the residual `fun_r`, which now incorporates the $\lambda$-dependency. In addition, we modify the `callback` function to also store the iterates of $\lambda$ in a list `lambd_list`.
class EikonalPINNIdentification(EikonalPINNSolver):
def fun_r(self, t, x, u, u_t, u_x, u_xx):
"""Residual of the PDE"""
return -u_t + tf.abs(u_x) - 1./self.model.lambd
def callback(self, xr=None):
lambd = self.model.lambd.numpy()
self.model.lambd_list.append(lambd)
if self.iter % 50 == 0:
print('It {:05d}: loss = {:10.8e} lambda = {:10.8e}'.format(self.iter, self.current_loss, lambd))
self.hist.append(self.current_loss)
self.iter += 1
def plot_loss_and_param(self, axs=None):
if axs:
ax1, ax2 = axs
self.plot_loss_history(ax1)
else:
ax1 = self.plot_loss_history()
ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
color = 'tab:blue'
ax2.tick_params(axis='y', labelcolor=color)
ax2.plot(range(len(self.hist)), self.model.lambd_list,'-',color=color)
ax2.set_ylabel('$\\lambda^{n_{epoch}}$', color=color)
return (ax1,ax2)
Finally, we set up the model consisting of only two hidden layers employing the Leaky ReLU function with slope parameter $\alpha = 0.1$, i.e.,
$$ \sigma(z) = \begin{cases} z & \text{ if } z \ge 0,\\ \alpha \, z & \text{ otherwise.} \end{cases} $$The training for $n_{epochs} = 10000$ epochs takes around 45 seconds.
# Initialize model
model = PINNIdentificationNet(lb, ub, num_hidden_layers=2,
activation=tf.keras.layers.LeakyReLU(alpha=0.1),
kernel_initializer='he_normal')
model.build(input_shape=(None,2))
# Initialize solver
eikonalIdentification = EikonalPINNIdentification(model, X_r)
# Choose step sizes aka learning rate
lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay([3000,7000],[1e-1,1e-2,1e-3])
# Solve with Adam optimizer
optim = tf.keras.optimizers.Adam(learning_rate=lr)
# Start timer
t0 = time()
eikonalIdentification.solve_with_TFoptimizer(optim, X_param, u_param, N=10001)
# Print computation time
print('\nComputation time: {} seconds'.format(time()-t0))
Plot the solution $u_\theta(t,x)$ as well as the evolution of the loss values $\phi^{n_\text{epoch}}$ and the estimated parameter $\lambda^{n_\text{epoch}}$.
ax = eikonalIdentification.plot_solution()
#plt.savefig('Eikonal_PI_Solution.pdf', bbox_inches='tight', dpi=300)
axs = eikonalIdentification.plot_loss_and_param()
#plt.savefig('Eikonal_PI_LossEvolution.pdf', bbox_inches='tight', dpi=300)
Finally, we compute the relative error of the identified parameter $\lambda$.
lambd_rel_error = np.abs((eikonalIdentification.model.lambd.numpy()-lambd_star)/lambd_star)
print('Relative error of lambda ', lambd_rel_error)
The next code cell performs the previous training five times in order to give a more reliable picture of the convergence behavior, since the weights are initialized randomly in each run (takes about 4 minutes).
lambd_hist = []
loss_hist = []
for i in range(5):
print('{:s}\nStart of iteration {:d}\n{:s}'.format(50*'-',i,50*'-'))
# Initialize model
model = PINNIdentificationNet(lb, ub, num_hidden_layers=2,
activation=tf.keras.layers.LeakyReLU(alpha=0.1),
kernel_initializer='he_normal')
model.build(input_shape=(None,2))
# Initialize solver
eikonalIdentification = EikonalPINNIdentification(model, X_r)
# Choose step sizes aka learning rate
lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay([3000,7000],[1e-1,1e-2,1e-3])
N=10001
# Solve with Adam optimizer
optim = tf.keras.optimizers.Adam(learning_rate=lr)
eikonalIdentification.solve_with_TFoptimizer(optim, X_param, u_param, N=N)
# Store evolution of lambdas
lambd_hist.append(model.lambd_list)
# Store evolution of losses
loss_hist.append(eikonalIdentification.hist)
Next, we generate a table printing the mean and standard deviations of the identified parameter $\lambda$ obtained for the previous runs.
print(' i Mean of lambda Std. of lambda')
for i in [1000,2000,3000,4000,5000,6000,7000,8000,9000,10000]:
xi = np.array([ x[i] for x in lambd_hist])
print('{:05d} {:6.4e} {:6.4e}'.format(i, xi.mean(), xi.std()))
Next, we plot the five evolutions of $\lambda$ (dark gray), its mean (solid blue) and one standard deviation (shaded blue) together with the true value of $\lambda$ (dashed blue).
fig = plt.figure()
ax = fig.add_subplot(111)
color = 'tab:blue'
Lambd = np.stack(lambd_hist)
lmean = Lambd.mean(axis=0)
lstd = Lambd.std(axis=0)
Lambd_RelError = np.abs((Lambd-lambd_star)/lambd_star)
lrange=range(len(lmean))
for i in range(len(lambd_hist)):
ax.plot(lrange, lambd_hist[i],'-',color='black', alpha=0.5)
ax.plot(lrange, lmean,'-',color=color)
ax.plot(lrange, lambd_star*np.ones_like(lmean),'--',color=color)
ax.fill_between(lrange,lmean-lstd,lmean+lstd, alpha=0.2)
ax.set_ylabel('$\\lambda^{n_{epoch}}$')
ax.set_xlabel('$n_{epoch}$')
ax.set_ylim([2.8,3.2])
#plt.savefig('Eikonal_PI_Evolution.pdf', bbox_inches='tight', dpi=300)
Finally, we plot the mean relative error of $\lambda$.
fig = plt.figure()
ax = fig.add_subplot(111)
ax.semilogy(lrange,Lambd_RelError.mean(axis=0))
ax.fill_between(lrange,Lambd_RelError.mean(axis=0)-Lambd_RelError.std(axis=0),
Lambd_RelError.mean(axis=0)+Lambd_RelError.std(axis=0), alpha=0.2)
ax.set_xlabel('$n_{epoch}$')
ax.set_ylabel('$e_{\\lambda}^{rel}$')
ax.set_title('Mean relative error of $\\lambda$');