import matplotlib.pyplot as plt
%matplotlib inline
from tsit5 import Tsit5Solver
from dopri5 import Dopri5Solver
from fixed_grid import Euler, Midpoint, RK4
from fixed_adams import AdamsBashforth, AdamsBashforthMoulton
from adams import VariableCoefficientAdamsBashforth
from misc import _check_inputs
from IPython.display import clear_output
from tqdm import tqdm
import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import animation
import pickle as pk
from JSAnimation import IPython_display
from adjoint import odeint_adjoint as ODEINT_ADJOINT
from odeint import odeint as ODEINT
from IPython.display import HTML
Необходимо приблизить решение обыкновенного дифференциального уравнения первого порядка:
$$
\frac{dz(t)}{dt} = f(z(t),A,t), \, z(t_0) = y_0
$$
$$
f(z(t),A,t) = Az(t)^3, \, A = \begin{vmatrix}
-0.1 & 2 \\
-2 & -0.1
\end{vmatrix}
$$
При этом по условию задачи, параметры и вид динамики $f(z(t),A,t)$ неизвестны. В таком случае динамика может быть смоделирована нейронной сетью, параметры которой необходимо обучить.
Общий подход к решению: C помощью одного из ODEsolver'ов получаем пиближение $z(t)$: $\hat{z}(t) = ODEsolver()$, затем считается $Loss(\hat{z})$ и делается градиентный шаг в сторону оптимума. В данном случае в качестве функции потерь был выбран $MAE$.
Т.к динамика f задействована явно, при получении приближения с помощью solver'а, то можно считать градиенты функции потерь в явном виде. Существует и другой способ, через решение сопряженной системы: $$ \frac{da(t)}{dt} = -a(t)\frac{\partial f(z(t),\theta,t)}{\partial z(t)}, \, a(T) = \frac{\partial L}{\partial y} $$ Тогда: $$ \frac{dL}{d\theta} = \int_{t_0}^{T} a(t) \frac{\partial f(z(t),\theta,t)}{\partial \theta} dt$$
Такой подход обладает тем преимуществом, что используется O(1) по памяти.
Предлагается сравнить работу данных методов, при различных значениях параметров batch_size, batch_time, learning_rate. За выбор метода отвечает аргумент adjoint.
class Arguments():
def __init__(self, method=None, data_size=None, batch_time=None, batch_size=None,\
niters=None, test_freq=None, viz=None, gpu=None, adjoint=None):
self.method = 'dopri5' if method is None else method #choices are 'dopri5'or 'adams'
self.data_size = 1000 if data_size is None else data_size #synthetic dataset size
self.batch_time = 10 if batch_time is None else batch_time #size of time period of ODE in batch
self.batch_size = 20 if batch_size is None else batch_size #number of ODEs in batch
self.niters = 2000 if niters is None else niters #training iterations number
self.test_freq = 20 if test_freq is None else test_freq #frequency of visualization
self.viz = False if viz is None else viz #if True then visualize
self.gpu = 0 if gpu is None else gpu #device index to select
self.adjoint = False if adjoint is None else adjoint #if True then solves through
# adjoint system
def get_batch():
s = torch.from_numpy(np.random.choice(np.arange(args.data_size\
- args.batch_time, dtype=np.int64), args.batch_size, replace=False))
batch_y0 = true_y[s] # (M, D)
batch_t = t[:args.batch_time] # (T)
batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0) # (T, M, D)
return batch_y0, batch_t, batch_y
def makedirs(dirname):
if not os.path.exists(dirname):
os.makedirs(dirname)
makedirs('png')
makedirs('pickle')
def visualize(true_y, pred_y, odefunc, itr, args):
if args.viz:
fig = plt.figure(figsize=(12, 4), facecolor='white')
ax_traj = fig.add_subplot(131, frameon=False)
ax_phase = fig.add_subplot(132, frameon=False)
ax_vecfield = fig.add_subplot(133, frameon=False)
ax_traj.cla()
ax_traj.set_title('Trajectories')
ax_traj.set_xlabel('t')
ax_traj.set_ylabel('x,y')
ax_traj.plot(t.numpy(), true_y.numpy()[:, 0, 0], t.numpy(), true_y.numpy()[:, 0, 1], 'g-')
ax_traj.plot(t.numpy(), pred_y.numpy()[:, 0, 0], '--', t.numpy(), pred_y.numpy()[:, 0, 1], 'b--')
ax_traj.set_xlim(t.min(), t.max())
ax_traj.set_ylim(-2, 2)
ax_traj.legend()
ax_phase.cla()
ax_phase.set_title('Phase Portrait')
ax_phase.set_xlabel('x')
ax_phase.set_ylabel('y')
ax_phase.plot(true_y.numpy()[:, 0, 0], true_y.numpy()[:, 0, 1], 'g-')
ax_phase.plot(pred_y.numpy()[:, 0, 0], pred_y.numpy()[:, 0, 1], 'b--')
ax_phase.set_xlim(-2, 2)
ax_phase.set_ylim(-2, 2)
ax_vecfield.cla()
ax_vecfield.set_title('Learned Vector Field')
ax_vecfield.set_xlabel('x')
ax_vecfield.set_ylabel('y')
y, x = np.mgrid[-2:2:21j, -2:2:21j]
dydt = odefunc(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2))).cpu().detach().numpy()
mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
dydt = (dydt / mag)
dydt = dydt.reshape(21, 21, 2)
ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
ax_vecfield.set_xlim(-2, 2)
ax_vecfield.set_ylim(-2, 2)
fig.tight_layout()
if args.adjoint:
plt.savefig('png/{:03d}_adj'.format(itr), bbox_inches='tight')
else:
plt.savefig('png/{:03d}_adj'.format(itr), bbox_inches='tight')
plt.draw()
plt.show()
class RunningAverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, momentum=0.99):
self.momentum = momentum
self.reset()
def reset(self):
self.val = None
self.avg = 0
def update(self, val):
if self.val is None:
self.avg = val
else:
self.avg = self.avg * self.momentum + val * (1 - self.momentum)
self.val = val
Настоящая функция, приблизить которую мы хотим:
class Lambda(nn.Module):
def forward(self, t, y):
return torch.mm(y**3, true_A)
Наше нейродиффиренциальное уравнение:
class ODEFunc(nn.Module):
def __init__(self):
super(ODEFunc, self).__init__()
self.net = nn.Sequential(
nn.Linear(2, 50),
nn.Tanh(),
nn.Linear(50, 2),
)
for m in self.net.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, mean=0, std=0.1)
nn.init.constant_(m.bias, val=0)
def forward(self, t, y):
return self.net(y**3)
Задаем значения аргументов:
args = Arguments()
args.viz = True
args.test_freq = 20
args.adjoint = False
args.batch_size = 100
args.lr=1e-3
args.niters=1000
Генерируем выборку:
if args.adjoint:
odeint = ODEINT
else:
odeint = ODEINT_ADJOINT
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
true_y0 = torch.tensor([[2., 0.]])
t = torch.linspace(0., 25., args.data_size)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])
with torch.no_grad():
true_y = odeint(Lambda(), true_y0, t, method='dopri5')
Запускаем процесс обучения, каждые args.test_freq итераций, считаем функцию потерь и визуализируем приближение:
ii = 0
func = ODEFunc()
optimizer = optim.RMSprop(func.parameters(), lr=args.lr)
end = time.time()
time_meter = RunningAverageMeter(0.97)
loss_meter = RunningAverageMeter(0.97)
iterator = tqdm(range(1, args.niters + 1))
for itr in iterator:
optimizer.zero_grad()
batch_y0, batch_t, batch_y = get_batch()
pred_y = odeint(func, batch_y0, batch_t)
loss = torch.mean(torch.abs(pred_y - batch_y))
loss.backward()
optimizer.step()
time_meter.update(time.time() - end)
loss_meter.update(loss.item())
if itr % args.test_freq == 0:
with torch.no_grad():
clear_output(wait=True)
pred_y = odeint(func, true_y0, t)
loss = torch.mean(torch.abs(pred_y - true_y))
iterator.set_description('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
visualize(true_y, pred_y, func, ii, args)
ii += 1
end = time.time()
clear_output()
plt.close()
Визуализация процесса обучения, на первом графике изображены траектории по $x$ и $y$ от времени($t$), на следующем изображено полученное приближение параметрической кривой в фазоом пространстве:
fig = plt.figure(figsize=(14,5))
im = plt.imshow(plt.imread('png/{:03d}.png'.format(0)))
def updatefig(j):
return im.set_data(plt.imread('png/{:03d}.png'.format(j))),
video = animation.FuncAnimation(fig, updatefig, frames=args.niters//args.test_freq, interval=700, blit=False)
plt.close()
HTML(video.to_html5_video())