In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
from tsit5 import Tsit5Solver
from dopri5 import Dopri5Solver
from fixed_grid import Euler, Midpoint, RK4
from fixed_adams import AdamsBashforth, AdamsBashforthMoulton
from adams import VariableCoefficientAdamsBashforth
from misc import _check_inputs
from IPython.display import clear_output
from tqdm import tqdm
import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import animation
import pickle as pk
from JSAnimation import IPython_display
from adjoint import odeint_adjoint as ODEINT_ADJOINT
from odeint import odeint as ODEINT
from IPython.display import HTML

Необходимо приблизить решение обыкновенного дифференциального уравнения первого порядка:
$$ \frac{dz(t)}{dt} = f(z(t),A,t), \, z(t_0) = y_0 $$ $$ f(z(t),A,t) = Az(t)^3, \, A = \begin{vmatrix} -0.1 & 2 \\ -2 & -0.1 \end{vmatrix} $$

При этом по условию задачи, параметры и вид динамики $f(z(t),A,t)$ неизвестны. В таком случае динамика может быть смоделирована нейронной сетью, параметры которой необходимо обучить.

Общий подход к решению: C помощью одного из ODEsolver'ов получаем пиближение $z(t)$: $\hat{z}(t) = ODEsolver()$, затем считается $Loss(\hat{z})$ и делается градиентный шаг в сторону оптимума. В данном случае в качестве функции потерь был выбран $MAE$.

Т.к динамика f задействована явно, при получении приближения с помощью solver'а, то можно считать градиенты функции потерь в явном виде. Существует и другой способ, через решение сопряженной системы: $$ \frac{da(t)}{dt} = -a(t)\frac{\partial f(z(t),\theta,t)}{\partial z(t)}, \, a(T) = \frac{\partial L}{\partial y} $$ Тогда: $$ \frac{dL}{d\theta} = \int_{t_0}^{T} a(t) \frac{\partial f(z(t),\theta,t)}{\partial \theta} dt$$

Такой подход обладает тем преимуществом, что используется O(1) по памяти.

Предлагается сравнить работу данных методов, при различных значениях параметров batch_size, batch_time, learning_rate. За выбор метода отвечает аргумент adjoint.

In [2]:
class Arguments():
    
    def __init__(self, method=None, data_size=None, batch_time=None, batch_size=None,\
                 niters=None, test_freq=None, viz=None, gpu=None, adjoint=None):
        
        self.method = 'dopri5' if method is None else method        #choices are 'dopri5'or 'adams'
        self.data_size = 1000 if data_size is None else data_size   #synthetic dataset size
        self.batch_time = 10 if batch_time is None else batch_time  #size of time period of ODE in batch 
        self.batch_size = 20 if batch_size is None else batch_size  #number of ODEs in batch
        self.niters = 2000 if niters is None else niters            #training iterations number
        self.test_freq = 20 if test_freq is None else test_freq     #frequency of visualization
        self.viz = False if viz is None else viz                    #if True then visualize
        self.gpu = 0 if gpu is None else gpu                        #device index to select
        self.adjoint = False if adjoint is None else adjoint        #if True then solves through
                                                                    #                   adjoint system

    
In [3]:
def get_batch():
    s = torch.from_numpy(np.random.choice(np.arange(args.data_size\
                                    - args.batch_time, dtype=np.int64), args.batch_size, replace=False))
    batch_y0 = true_y[s]  # (M, D)
    batch_t = t[:args.batch_time]  # (T)
    batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0)  # (T, M, D)
    return batch_y0, batch_t, batch_y


def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)

makedirs('png')
makedirs('pickle')

def visualize(true_y, pred_y, odefunc, itr, args):

    if args.viz:
        
        fig = plt.figure(figsize=(12, 4), facecolor='white')
        ax_traj = fig.add_subplot(131, frameon=False)
        ax_phase = fig.add_subplot(132, frameon=False)
        ax_vecfield = fig.add_subplot(133, frameon=False)

        ax_traj.cla()
        ax_traj.set_title('Trajectories')
        ax_traj.set_xlabel('t')
        ax_traj.set_ylabel('x,y')
        ax_traj.plot(t.numpy(), true_y.numpy()[:, 0, 0], t.numpy(), true_y.numpy()[:, 0, 1], 'g-')
        ax_traj.plot(t.numpy(), pred_y.numpy()[:, 0, 0], '--', t.numpy(), pred_y.numpy()[:, 0, 1], 'b--')
        ax_traj.set_xlim(t.min(), t.max())
        ax_traj.set_ylim(-2, 2)
        ax_traj.legend()

        ax_phase.cla()
        ax_phase.set_title('Phase Portrait')
        ax_phase.set_xlabel('x')
        ax_phase.set_ylabel('y')
        ax_phase.plot(true_y.numpy()[:, 0, 0], true_y.numpy()[:, 0, 1], 'g-')
        ax_phase.plot(pred_y.numpy()[:, 0, 0], pred_y.numpy()[:, 0, 1], 'b--')
        ax_phase.set_xlim(-2, 2)
        ax_phase.set_ylim(-2, 2)

        ax_vecfield.cla()
        ax_vecfield.set_title('Learned Vector Field')
        ax_vecfield.set_xlabel('x')
        ax_vecfield.set_ylabel('y')

        y, x = np.mgrid[-2:2:21j, -2:2:21j]
        dydt = odefunc(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2))).cpu().detach().numpy()
        mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
        dydt = (dydt / mag)
        dydt = dydt.reshape(21, 21, 2)

        ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
        ax_vecfield.set_xlim(-2, 2)
        ax_vecfield.set_ylim(-2, 2)

        fig.tight_layout()
        if args.adjoint:
            plt.savefig('png/{:03d}_adj'.format(itr), bbox_inches='tight')
        else:
            plt.savefig('png/{:03d}_adj'.format(itr), bbox_inches='tight')
        plt.draw()
        plt.show()


class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val

Настоящая функция, приблизить которую мы хотим:

In [42]:
class Lambda(nn.Module):

    def forward(self, t, y):
        return torch.mm(y**3, true_A)

Наше нейродиффиренциальное уравнение:

In [43]:
class ODEFunc(nn.Module):
    
    def __init__(self):
        super(ODEFunc, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Tanh(),
            nn.Linear(50, 2),
        )

        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0)

    def forward(self, t, y):
        return self.net(y**3)

Случай прямой оптимизации через внутреннее представление solver'а.

Задаем значения аргументов:

In [37]:
args = Arguments()
args.viz = True
args.test_freq = 20
args.adjoint = False
args.batch_size = 100
args.lr=1e-3
args.niters=1000

Генерируем выборку:

In [30]:
if args.adjoint:
    odeint = ODEINT
else:
    odeint = ODEINT_ADJOINT
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

true_y0 = torch.tensor([[2., 0.]])
t = torch.linspace(0., 25., args.data_size)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])

with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method='dopri5')

Запускаем процесс обучения, каждые args.test_freq итераций, считаем функцию потерь и визуализируем приближение:

In [31]:
ii = 0

func = ODEFunc()
optimizer = optim.RMSprop(func.parameters(), lr=args.lr)
end = time.time()

time_meter = RunningAverageMeter(0.97)
loss_meter = RunningAverageMeter(0.97)

iterator =  tqdm(range(1, args.niters + 1))
for itr in iterator:
    optimizer.zero_grad()
    batch_y0, batch_t, batch_y = get_batch()
    pred_y = odeint(func, batch_y0, batch_t)
    loss = torch.mean(torch.abs(pred_y - batch_y))
    loss.backward()
    optimizer.step()

    time_meter.update(time.time() - end)
    loss_meter.update(loss.item())

    if itr % args.test_freq == 0:
        with torch.no_grad():
            clear_output(wait=True)
            pred_y = odeint(func, true_y0, t)
            loss = torch.mean(torch.abs(pred_y - true_y))
            iterator.set_description('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
            visualize(true_y, pred_y, func, ii, args)
            ii += 1
            
    
    end = time.time()
clear_output()
plt.close()

Визуализация процесса обучения, на первом графике изображены траектории по $x$ и $y$ от времени($t$), на следующем изображено полученное приближение параметрической кривой в фазоом пространстве:

In [46]:
fig = plt.figure(figsize=(14,5))

im = plt.imshow(plt.imread('png/{:03d}.png'.format(0)))

def updatefig(j):
    return im.set_data(plt.imread('png/{:03d}.png'.format(j))),

video = animation.FuncAnimation(fig, updatefig, frames=args.niters//args.test_freq, interval=700, blit=False)
plt.close()

HTML(video.to_html5_video())
Out[46]:

Случай оптимизации через решение сопряженной системы.

Задаем значения аргументов:

In [49]:
args = Arguments()
args.viz = True
args.test_freq = 20
args.adjoint = True
args.batch_size = 100
args.niters = 1000
args.lr=1e-3

Генерируем выборку:

In [30]:
if args.adjoint:
    odeint = ODEINT
else:
    odeint = ODEINT_ADJOINT
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

true_y0 = torch.tensor([[2., 0.]])
t = torch.linspace(0., 25., args.data_size)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])

with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method='dopri5')

Запускаем процесс обучения, каждые args.test_freq итераций, считаем функцию потерь и визуализируем приближение:

In [48]:
ii = 0

func = ODEFunc()
#optimizer = optim.RMSprop(func.parameters(), lr=1e-3)
optimizer = optim.RMSprop(func.parameters(), lr=args.lr)
end = time.time()

time_meter = RunningAverageMeter(0.97)
loss_meter = RunningAverageMeter(0.97)

iterator =  tqdm(range(1, args.niters + 1))
for itr in iterator:
    optimizer.zero_grad()
    batch_y0, batch_t, batch_y = get_batch()
    pred_y = odeint(func, batch_y0, batch_t)
    loss = torch.mean(torch.abs(pred_y - batch_y))
    loss.backward()
    optimizer.step()

    time_meter.update(time.time() - end)
    loss_meter.update(loss.item())

    if itr % args.test_freq == 0:
        with torch.no_grad():
            clear_output(wait=True)
            pred_y = odeint(func, true_y0, t)
            loss = torch.mean(torch.abs(pred_y - true_y))
            iterator.set_description('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
            visualize(true_y, pred_y, func, ii, args)
            ii += 1
            
    
    end = time.time()
clear_output()
plt.close()

Визуализация процесса обучения, на первом графике изображены траектории по $x$ и $y$ от времени($t$), на следующем изображено полученное приближение параметрической кривой в фазоом пространстве:

In [50]:
fig = plt.figure(figsize=(14,5))

im = plt.imshow(plt.imread('png/{:03d}_adj.png'.format(0)))

def updatefig(j):
    return im.set_data(plt.imread('png/{:03d}_adj.png'.format(j))),

video = animation.FuncAnimation(fig, updatefig, frames=args.niters//args.test_freq, interval=700, blit=False)
plt.close()

HTML(video.to_html5_video())
Out[50]:

Замечания:

  • Можно заметить, что метод основаный на решении сопряженной системы дифференциальных уравнений, работает медленнее.
  • Оба метода успешно справляются с поставленной задачей.
  • При увеличении значений batch_size и batch_time обучение становится более стабильным, но замедляется.
  • То же самое при уменьшении параметра lr, отвечающего за размер шага градиентного спуска.