import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm_notebook as tqdm
from fixed_grid import Euler, Midpoint, RK4
from dopri5 import Dopri5Solver
from misc import _decreasing, _check_inputs, _flatten, _flatten_convert_none_to_zeros
ODEINT и odeint_adjoint
Функции ODEINT и odeint_adjoint вычисляют выход модели, численно интегрируя ОДУ. При этом odeint_adjoint вычисляет градиенты методом сопряжённых переменных (adjoint method), поэтому объём используемой им памяти постоянен — O(1) относительно числа шагов интегрирования, — в отличие от ODEINT, у которого он растёт с числом шагов.
def ODEINT(func, y0, t, rtol=1e-7, atol=1e-9, method=None, options=None):
    """Integrate a system of ordinary differential equations.

    Solves the initial value problem for a non-stiff system of first-order ODEs::

        dy/dt = func(t, y),    y(t[0]) = y0

    where ``y`` is a Tensor (or a tuple of Tensors) of any shape.

    Args:
        func: callable mapping a time ``t`` and a state ``y`` to dy/dt, with the
            same structure as ``y``.
        y0: initial state at time ``t[0]``.
        t: 1-D Tensor of time points at which the solution is reported; ``t[0]``
            is the initial time. NOTE(review): ordering/dtype requirements are
            enforced by ``_check_inputs`` (not visible here); the adjoint backward
            pass in this file calls ODEINT with a decreasing pair of times.
        rtol: relative error tolerance forwarded to the solver.
        atol: absolute error tolerance forwarded to the solver.
        method: key of the module-level ``SOLVERS`` dict; ``None`` selects 'dopri5'.
        options: extra keyword arguments forwarded to the solver constructor.
            May only be given together with an explicit ``method``.

    Returns:
        The solution, whose first dimension indexes the entries of ``t`` (the
        initial value ``y0`` comes first). When ``_check_inputs`` reports a single
        Tensor input, that single Tensor is returned instead of a 1-tuple.

    Raises:
        ValueError: if ``options`` is supplied without ``method``, or if
            ``method`` is not a key of ``SOLVERS``.
    """
    # Validate the configuration before touching the inputs so that a bad
    # method/options combination fails fast with a clear ValueError.
    if options is None:
        options = {}
    elif method is None:
        raise ValueError('cannot supply `options` without specifying `method`')
    if method is None:
        method = 'dopri5'  # default solver
    if method not in SOLVERS:
        raise ValueError('Invalid method "{}". Must be one of {}'.format(
            method, ', '.join('"{}"'.format(name) for name in SOLVERS)))

    tensor_input, func, y0, t = _check_inputs(func, y0, t)
    solver = SOLVERS[method](func, y0, rtol=rtol, atol=atol, **options)
    solution = solver.integrate(t)
    # Unwrap the 1-tuple that _check_inputs created for a plain Tensor input.
    if tensor_input:
        solution = solution[0]
    return solution
class OdeintAdjointMethod(torch.autograd.Function):
    """Autograd Function that solves an ODE and back-propagates with the adjoint method.

    forward runs ODEINT without building an autograd graph. backward integrates an
    augmented system (original state, adjoint of the state, a time integrator and a
    parameter integrator) backwards between each pair of consecutive time points.
    """

    @staticmethod
    def forward(ctx, *args):
        """Solve the ODE.

        `args` is laid out as (*y0, func, t, flat_params, rtol, atol, method, options):
        every leading entry is a state tensor and the last seven are fixed-position
        arguments unpacked below. Returns what ODEINT returns for the tuple state
        (one solution per state tensor, as unpacked again in backward).
        """
        assert len(args) >= 8, 'Internal error: all arguments required.'
        y0, func, t, flat_params, rtol, atol, method, options = \
            args[:-7], args[-7], args[-6], args[-5], args[-4], args[-3], args[-2], args[-1]
        # Non-tensor configuration is stashed on ctx for use in backward.
        ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options = func, rtol, atol, method, options
        # No graph is recorded through the solver; gradients come from backward instead.
        with torch.no_grad():
            ans = ODEINT(func, y0, t, rtol=rtol, atol=atol, method=method, options=options)
        # flat_params is saved so backward can build a matching zero accumulator.
        ctx.save_for_backward(t, flat_params, *ans)
        return ans

    @staticmethod
    def backward(ctx, *grad_output):
        """Compute gradients w.r.t. the state tensors, `t` and `flat_params`.

        `grad_output` holds one incoming gradient per solution tensor; each is
        indexed along its first (time) dimension below.
        """
        t, flat_params, *ans = ctx.saved_tensors
        ans = tuple(ans)
        func, rtol, atol, method, options = ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options
        n_tensors = len(ans)
        f_params = tuple(func.parameters())

        # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
        def augmented_dynamics(t, y_aug):
            # Dynamics of the original system augmented with
            # the adjoint wrt y, and an integrator wrt t and args.
            # y_aug layout: (*y, *adj_y, adj_time, adj_params).
            y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

            with torch.set_grad_enabled(True):
                # Fresh leaves so the vector-Jacobian products below are taken at this point only.
                t = t.to(y[0].device).detach().requires_grad_(True)
                y = tuple(y_.detach().requires_grad_(True) for y_ in y)
                func_eval = func(t, y)
                # VJP of f with -adj_y gives the adjoint derivatives w.r.t. t, y and the parameters.
                vjp_t, *vjp_y_and_params = torch.autograd.grad(
                    func_eval, (t,) + y + f_params,
                    tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True
                )
            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # autograd.grad returns None if no gradient, set to zero.
            vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
            vjp_y = tuple(torch.zeros_like(y_) if vjp_y_ is None else vjp_y_ for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            # A parameterless func still needs a tensor slot for adj_params in the state.
            if len(f_params) == 0:
                vjp_params = torch.tensor(0.).to(vjp_y[0])
            # Time derivative of every component of y_aug, in the same layout.
            return (*func_eval, *vjp_y, vjp_t, vjp_params)

        T = ans[0].shape[0]
        with torch.no_grad():
            # The adjoint starts from the incoming gradient at the last time point.
            adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
            adj_params = torch.zeros_like(flat_params)
            adj_time = torch.tensor(0.).to(t)
            time_vjps = []
            # Walk backwards over consecutive time intervals [t[i-1], t[i]].
            for i in range(T - 1, 0, -1):

                ans_i = tuple(ans_[i] for ans_ in ans)
                grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)
                func_i = func(t[i], ans_i)

                # Compute the effect of moving the current time measurement point.
                # dL/dt_i = sum over tensors of <f(t_i, y_i), dL/dy_i>.
                dLd_cur_t = sum(
                    torch.dot(func_i_.reshape(-1), grad_output_i_.reshape(-1)).reshape(1)
                    for func_i_, grad_output_i_ in zip(func_i, grad_output_i)
                )
                adj_time = adj_time - dLd_cur_t
                time_vjps.append(dLd_cur_t)

                # Run the augmented system backwards in time.
                if adj_params.numel() == 0:
                    adj_params = torch.tensor(0.).to(adj_y[0])
                aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
                # NOTE(review): torch.tensor([...]) here is built without a device/dtype
                # argument, so it may not live on t's device; confirm on GPU runs.
                aug_ans = ODEINT(
                    augmented_dynamics, aug_y0,
                    torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options
                )

                # Unpack aug_ans.
                adj_y = aug_ans[n_tensors:2 * n_tensors]
                adj_time = aug_ans[2 * n_tensors]
                adj_params = aug_ans[2 * n_tensors + 1]

                # Index 1 along the time axis is the value at t[i - 1].
                adj_y = tuple(adj_y_[1] if len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
                if len(adj_time) > 0: adj_time = adj_time[1]
                if len(adj_params) > 0: adj_params = adj_params[1]

                # Inject the loss gradient observed at t[i - 1].
                adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))

                # Release the augmented trajectory before the next interval.
                del aug_y0, aug_ans

            # The remaining time adjoint is the gradient w.r.t. t[0].
            time_vjps.append(adj_time)
            # Entries were collected last-to-first; reverse to match t's order.
            time_vjps = torch.cat(time_vjps[::-1])

            # One gradient per forward input: *y0, func, t, flat_params, rtol, atol,
            # method, options. NOTE(review): there are five trailing Nones for the four
            # non-tensor inputs after flat_params; PyTorch tolerates extra trailing
            # None gradients, but confirm if this signature changes.
            return (*adj_y, None, time_vjps, adj_params, None, None, None, None, None)
def odeint_adjoint(func, y0, t, rtol=1e-6, atol=1e-12, method=None, options=None):
    """Solve an ODE like ODEINT, but back-propagate with the adjoint method.

    `func` must be an ``nn.Module``: the backward pass reads its parameters to
    compute their gradients. A single-Tensor `y0` is wrapped into a 1-tuple
    (with `func` wrapped to match) and the result is unwrapped again.

    Raises:
        ValueError: if `func` is not an ``nn.Module``.
    """
    # Parameters are only reachable through the module; we have no other way of
    # getting the variables used along the execution path.
    if not isinstance(func, nn.Module):
        raise ValueError('func is required to be an instance of nn.Module.')

    tensor_input = torch.is_tensor(y0)
    if tensor_input:
        class TupleFunc(nn.Module):
            # Adapts a Tensor -> Tensor dynamics function to the tuple interface.
            def __init__(self, base_func):
                super(TupleFunc, self).__init__()
                self.base_func = base_func

            def forward(self, t, y):
                return (self.base_func(t, y[0]),)

        func = TupleFunc(func)
        y0 = (y0,)

    flat_params = _flatten(func.parameters())
    ys = OdeintAdjointMethod.apply(*y0, func, t, flat_params, rtol, atol, method, options)
    return ys[0] if tensor_input else ys
Вспомогательные функции
class Flatten(nn.Module):
    """Flatten every dimension except the leading (batch) one: (N, ...) -> (N, prod(...))."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # torch.flatten works on non-contiguous inputs (where .view would raise)
        # and avoids building a tensor from the shape on every call.
        return torch.flatten(x, start_dim=1)
class RunningAverageMeter(object):
    """Track the latest value of a scalar stream and its exponential moving average."""

    def __init__(self, momentum=0.99):
        # Weight given to the previous average on each update.
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Forget all history."""
        self.val = None
        self.avg = 0

    def update(self, val):
        """Record `val`; the first value seeds the average directly."""
        if self.val is not None:
            self.avg = self.momentum * self.avg + (1 - self.momentum) * val
        else:
            self.avg = val
        self.val = val
def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
    """Create MNIST data loaders, downloading the dataset into data/mnist if needed.

    Args:
        data_aug: if True, training images get RandomCrop(28, padding=4) before ToTensor.
        batch_size: batch size of the training loader.
        test_batch_size: batch size of both evaluation loaders.
        perc: unused; kept so existing callers keep working.

    Returns:
        (train_loader, test_loader, train_eval_loader). The training loader shuffles
        and drops its last incomplete batch; the two evaluation loaders are not
        shuffled and keep every sample.
    """
    train_steps = [transforms.RandomCrop(28, padding=4)] if data_aug else []
    transform_train = transforms.Compose(train_steps + [transforms.ToTensor()])
    transform_test = transforms.Compose([transforms.ToTensor()])

    def _mnist(train, transform):
        # One MNIST split under the shared root directory.
        return datasets.MNIST(root='data/mnist', train=train, download=True, transform=transform)

    train_loader = DataLoader(
        _mnist(True, transform_train), batch_size=batch_size,
        shuffle=True, num_workers=0, drop_last=True
    )
    # Evaluation loaders must not drop the final partial batch; otherwise up to
    # test_batch_size - 1 samples of each split would never be evaluated.
    train_eval_loader = DataLoader(
        _mnist(True, transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=0, drop_last=False
    )
    test_loader = DataLoader(
        _mnist(False, transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=0, drop_last=False
    )
    return train_loader, test_loader, train_eval_loader
def inf_generator(iterable):
    """Cycle over a re-iterable (e.g. a DataLoader) forever, restarting it at each pass.

    Allows training with DataLoaders in a single infinite loop::

        for i, (x, y) in enumerate(inf_generator(train_loader)):
            ...

    Unlike ``itertools.cycle`` nothing is cached: ``iter(iterable)`` is called
    again on every pass, so a shuffling DataLoader reshuffles each epoch.

    Raises:
        ValueError: if a full pass over `iterable` yields nothing (this used to
            spin forever).
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            raise ValueError('inf_generator: iterable produced no items')
def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates,
                             base_lr=None):
    """Build a piecewise-constant (step-decay) learning-rate schedule.

    The initial rate is ``base_lr * batch_size / batch_denom``. Segment ``k`` of the
    schedule uses ``initial_rate * decay_rates[k]``, and it switches at iteration
    ``int(batches_per_epoch * boundary_epochs[k])``.

    Args:
        batch_size: actual batch size.
        batch_denom: reference batch size the base rate is scaled against.
        batches_per_epoch: iterations per epoch, used to convert epochs to iterations.
        boundary_epochs: epochs at which the rate changes.
        decay_rates: multipliers, one per segment (at least len(boundary_epochs) + 1).
        base_lr: base learning rate; defaults to the module-level ``LR``.

    Returns:
        A function mapping an iteration number to its learning rate.

    Raises:
        ValueError: if there are fewer decay rates than segments.
    """
    if len(decay_rates) < len(boundary_epochs) + 1:
        raise ValueError(
            'decay_rates needs at least len(boundary_epochs) + 1 entries, '
            'got {} for {} boundaries'.format(len(decay_rates), len(boundary_epochs)))
    if base_lr is None:
        base_lr = LR
    initial_learning_rate = base_lr * batch_size / batch_denom
    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        # Index of the first boundary not yet reached; len(boundaries) once past all.
        i = next((k for k, b in enumerate(boundaries) if itr < b), len(boundaries))
        return vals[i]

    return learning_rate_fn
def one_hot(x, K):
    """One-hot encode a 1-D integer array `x` over `K` classes as an int matrix.

    Labels outside [0, K) produce an all-zero row.
    """
    classes = np.arange(K)
    mask = x[:, None] == classes[None, :]
    return np.array(mask, dtype=int)
def accuracy(model, dataset_loader, device=None):
    """Return the classification accuracy of `model` on `dataset_loader`, in percent.

    Args:
        model: module mapping an input batch to per-class scores; the prediction
            is the argmax along dim 1.
        dataset_loader: iterable of (inputs, integer class labels) batches.
        device: where inputs are moved before the forward pass. Defaults to the
            device of the model's first parameter (CPU if it has none).

    The denominator is the number of samples actually evaluated, so a loader built
    with ``drop_last=True`` does not deflate the score. The model runs in eval mode
    without autograd, and its previous train/eval mode is restored afterwards.

    Raises:
        ValueError: if the loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    total_correct = 0
    total_seen = 0
    try:
        with torch.no_grad():
            for x, y in dataset_loader:
                predicted = model(x.to(device)).argmax(dim=1).cpu()
                total_correct += (predicted == y.cpu()).sum().item()
                total_seen += len(y)
    finally:
        model.train(was_training)

    if total_seen == 0:
        raise ValueError('accuracy: dataset_loader produced no samples')
    return (total_correct / total_seen) * 100
def count_parameters(model):
    """Return the number of trainable scalar parameters in `model` (a rough size measure)."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
def makedirs(dirname):
    """Create directory `dirname` and any missing parents; do nothing if it already exists."""
    # exist_ok=True removes the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(dirname, exist_ok=True)
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1 (size-preserving at stride 1)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 convolution (no padding)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
def norm(dim, max_groups=32):
    """Return a GroupNorm layer over `dim` channels (https://arxiv.org/pdf/1803.08494.pdf).

    Uses ``min(max_groups, dim)`` groups, as before. When that count does not divide
    `dim` (e.g. dim=48), it falls back to the largest divisor of `dim` that is at
    most `max_groups`, instead of letting ``nn.GroupNorm`` raise.

    Raises:
        ValueError: if `dim` or `max_groups` is smaller than 1.
    """
    if dim < 1 or max_groups < 1:
        raise ValueError('norm: dim and max_groups must be positive, got {} and {}'.format(dim, max_groups))
    groups = min(max_groups, dim)
    # Terminates at groups == 1 at the latest, since 1 divides every dim.
    while dim % groups:
        groups -= 1
    return nn.GroupNorm(groups, dim)
Res-блоки задают архитектуру остаточной сети (ResNet)
class ResBlock(nn.Module):
    """Pre-activation residual block for the ResNet baseline.

    out = conv2(relu(norm2(conv1(relu(norm1(x)))))) + shortcut, where the shortcut
    is `x` itself, or `downsample` applied to relu(norm1(x)) when one is given.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        # Registration order is unchanged so parameter and state_dict order stay the same.
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        activated = self.relu(self.norm1(x))
        residual = x if self.downsample is None else self.downsample(activated)
        hidden = self.relu(self.norm2(self.conv1(activated)))
        return self.conv2(hidden) + residual  # F(x) + x
ODE-блоки задают архитектуру ODE-сети
class ConcatConv2d(nn.Module):
    """Convolution whose input is augmented with one constant channel holding the time `t`."""

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
        # dim_in + 1: the extra input channel carries t.
        self._layer = conv_cls(
            dim_in + 1,
            dim_out,
            kernel_size=ksize,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, t, x):
        # A single channel shaped like x's spatial dims, filled with t.
        time_channel = torch.ones_like(x[:, :1, :, :]) * t
        return self._layer(torch.cat([time_channel, x], 1))
class ODEfunc(nn.Module):
    """Right-hand side f(t, h) of the ODE solved by ODEBlock.

    norm -> relu -> time-conditioned conv -> norm -> relu -> time-conditioned conv -> norm.
    `nfe` counts calls to forward (number of function evaluations).
    """

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, ksize=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, ksize=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        # Number of function evaluations; it grows with the number of solver steps.
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        hidden = self.relu(self.norm1(x))
        # t enters both convolutions through ConcatConv2d's extra channel.
        hidden = self.relu(self.norm2(self.conv1(t, hidden)))
        return self.norm3(self.conv2(t, hidden))
class ODEBlock(nn.Module):
    """Integrate `odefunc` from t=0 to t=1 and return the final state.

    Uses the module-level `odeint`, `TOL` (as both rtol and atol) and `METHOD`.
    """

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc  # dynamics module, e.g. ODEfunc
        self.integration_time = torch.tensor([0.0, 1.0], dtype=torch.float32)

    def forward(self, x):
        # Match the time grid's dtype (and device) to the input.
        self.integration_time = self.integration_time.type_as(x)
        trajectory = odeint(self.odefunc, x, self.integration_time, rtol=TOL, atol=TOL, method=METHOD)
        # trajectory[0] is the initial state; trajectory[1] is the state at t=1.
        return trajectory[1]

    @property
    def nfe(self):
        """Function-evaluation counter of the wrapped odefunc."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value
TOL (tolerance) — допуск (rtol и atol) для функций ODEINT и odeint_adjoint. Изменяя его, можно адаптировать модель: например, получить более быструю, но менее точную.
METHOD — параметр функций ODEINT и odeint_adjoint, определяющий, какой именно численный метод будет использоваться. От него зависят точность и время работы.
Параметры эксперимента
Чтобы убедиться во влиянии TOL, можно увеличить или уменьшить его и сравнить результаты.
Также можно поменять параметр METHOD, выбрав его из SOLVERS.
Можно сравнить ODE-Net с ResNet по числу параметров и скорости сходимости; для этого следует переключать параметр is_odenet.
Чтобы проверить утверждение о константной памяти при использовании odeint_adjoint, нужно заменить значение параметра odeint с ODEINT на odeint_adjoint и значительно увеличить BTCH_SZ, например до 200. При таком размере BTCH_SZ и использовании ODEINT ноутбук, скорее всего, упадёт.
Если хочется в точности повторить эксперимент из статьи, можно выставить число эпох (NEPOCHS) равным 128, но это займёт довольно много времени.
# --- Experiment configuration ---
TOL = 1e-3  # rtol/atol passed to odeint by ODEBlock

# Available numerical solvers, keyed by the names accepted as METHOD.
SOLVERS = dict(dopri5=Dopri5Solver, euler=Euler, midpoint=Midpoint, rk4=RK4)
METHOD = 'rk4'  # one of 'dopri5', 'euler', 'midpoint', 'rk4'

LR = 0.1  # base learning rate used by learning_rate_with_decay
odeint = ODEINT  # ODEINT or odeint_adjoint
is_odenet = True  # False -> the experiment uses the ResNet feature layers instead of the ODE-Net

BTCH_SZ = 50  # training batch size
NEPOCHS = 5  # number of training epochs
makedirs("./experiment")
device = 'cpu'
downsampling_layers = [
nn.Conv2d(1, 64, 3, 1),
norm(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 4, 2, 1),
norm(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 4, 2, 1),
]
feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]
Инициализация модели
# The full network: stem -> feature block(s) -> classification head.
model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

# Training objective.
criterion = nn.CrossEntropyLoss().to(device)

# Data loaders; the training split uses random-crop augmentation.
train_loader, test_loader, train_eval_loader = get_mnist_loaders(
    data_aug=True,
    batch_size=BTCH_SZ,
    test_batch_size=1000,
)

# Endless stream of training batches.
data_gen = inf_generator(train_loader)
batches_per_epoch = len(train_loader)

# Step-decay learning-rate schedule, indexed by iteration.
# NOTE: with the default NEPOCHS the first boundary (epoch 60) is never reached.
lr_fn = learning_rate_with_decay(
    batch_size=BTCH_SZ,
    batch_denom=128,
    batches_per_epoch=batches_per_epoch,
    boundary_epochs=[60, 100, 140],
    decay_rates=[1, 0.1, 0.01, 0.001],
)

# SGD with momentum; the lr set here is overwritten by lr_fn every iteration.
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)

# Bookkeeping: batch time plus forward/backward function-evaluation counts.
best_acc = 0
batch_time_meter, f_nfe_meter, b_nfe_meter = (RunningAverageMeter() for _ in range(3))
end = time.time()
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz Processing... Done!
Число параметров
Можно заметить, что у ResNet их больше.
# Trainable parameters of each stage (stem, features, head), then the total.
stage_sizes = [
    count_parameters(nn.Sequential(*stage))
    for stage in (downsampling_layers, feature_layers, fc_layers)
]
print(*stage_sizes)
print('Number of parameters: {}'.format(count_parameters(model)))
132096 75392 778 Number of parameters: 208266
Запуск эксперимента
# Main loop: NEPOCHS * batches_per_epoch SGD iterations; the test error is
# reported at the end of every epoch.
ode_block = feature_layers[0] if is_odenet else None
for itr in tqdm(range(1, NEPOCHS * batches_per_epoch + 1)):
    lr_now = lr_fn(itr)
    for group in optimizer.param_groups:
        group['lr'] = lr_now
    optimizer.zero_grad()

    x, y = next(data_gen)
    x, y = x.to(device), y.to(device)

    logits = model(x)
    loss = criterion(logits, y)
    if ode_block is not None:
        # Function evaluations spent in the forward solve.
        nfe_forward = ode_block.nfe
        ode_block.nfe = 0

    loss.backward()
    optimizer.step()
    if ode_block is not None:
        # Function evaluations spent in the backward pass.
        nfe_backward = ode_block.nfe
        ode_block.nfe = 0

    batch_time_meter.update(time.time() - end)
    if ode_block is not None:
        f_nfe_meter.update(nfe_forward)
        b_nfe_meter.update(nfe_backward)
    end = time.time()

    if itr % batches_per_epoch == 0:
        with torch.no_grad():
            val_acc = accuracy(model, test_loader)
        print("epoch: ", itr/ batches_per_epoch,"te_err (%) : ", 100 - val_acc)
Failed to display Jupyter Widget of type HBox
.
If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean that the widgets JavaScript is still loading. If this message persists, it likely means that the widgets JavaScript library is either not installed or not enabled. See the Jupyter Widgets Documentation for setup instructions.
If you're reading this message in another frontend (for example, a static rendering on GitHub or NBViewer), it may mean that your frontend doesn't currently support widgets.
epoch: 1.0 te_err (%) : 2.969999999999999 epoch: 2.0 te_err (%) : 0.9399999999999977 epoch: 3.0 te_err (%) : 1.2199999999999989 epoch: 4.0 te_err (%) : 0.7600000000000051 epoch: 5.0 te_err (%) : 1.0100000000000051