Данный ноутбук содержит код эксперимента, в котором исследуется, позволяет ли модель нейронного обыкновенного дифференциального уравнения (Neural ODE) явным образом управлять компромиссом (trade-off) между точностью численного решения и вычислительными затратами.
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchdiffeq import odeint_adjoint as odeint
from utils import norm, Flatten, get_mnist_loaders, one_hot, ConcatConv2d
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Convolutional stem applied before the ODE block: one 3x3 convolution
# followed by two stride-2 4x4 convolutions, all with 64 output channels.
# `norm` comes from the project's `utils` module.
downsampling_layers = [
    nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1),
    norm(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
    norm(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
]

# Classification head applied after the ODE block: normalise, activate,
# average-pool to 1x1, flatten and map the 64 features to 10 outputs.
fc_layers = [
    norm(64),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
    Flatten(),
    nn.Linear(in_features=64, out_features=10),
]
class ODEfunc(nn.Module):
    """Dynamics function ``f(t, x)`` evaluated by the ODE solver.

    The state goes through norm -> ReLU -> conv twice and a final norm.
    Both convolutions are ``ConcatConv2d`` layers that also receive the
    current time ``t``.  Every call to ``forward`` increments ``nfe``, so the
    attribute counts function evaluations since it was last reset.
    """

    def __init__(self, dim):
        """Create the layers for a state with ``dim`` channels."""
        super().__init__()
        # Attribute names and creation order are kept stable so that
        # state-dict keys match previously saved checkpoints.
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        # Number of function evaluations performed so far.
        self.nfe = 0

    def forward(self, t, x):
        """Return the layer stack's output for state ``x`` at time ``t``."""
        self.nfe += 1
        hidden = self.conv1(t, self.relu(self.norm1(x)))
        hidden = self.conv2(t, self.relu(self.norm2(hidden)))
        return self.norm3(hidden)
class ODEBlock(nn.Module):
    """Layer that integrates ``odefunc`` from t=0 to t=1 with ``odeint``.

    ``tol`` is passed to the solver as both ``rtol`` and ``atol``; ``method``
    is forwarded unchanged.
    """

    def __init__(self, odefunc, tol=1e-3, method=None):
        """Store the dynamics function and the solver settings."""
        super().__init__()
        self.odefunc = odefunc
        # Start and end points of the integration interval.
        self.integration_time = torch.tensor([0, 1], dtype=torch.float32)
        self.tol = tol
        self.method = method

    def forward(self, x):
        """Return the solution at the final integration time for input ``x``."""
        # Match the time grid to the input's tensor type and keep the cast
        # version on the instance.
        time_grid = self.integration_time.type_as(x)
        self.integration_time = time_grid
        trajectory = odeint(
            self.odefunc,
            x,
            time_grid,
            rtol=self.tol,
            atol=self.tol,
            method=self.method,
        )
        # Index 0 is the solution at t=0, index 1 the solution at t=1.
        return trajectory[1]

    @property
    def nfe(self):
        """Function-evaluation counter of the wrapped ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value
Загрузим веса заранее обученной модели ODE-Net (при обучении использовался допуск численного метода $tol = 10^{-3}$; в коде он задаёт одновременно относительную и абсолютную погрешности, $rtol = atol = tol$):
# Evaluate the pre-trained ODE-Net on the MNIST test set for several solver
# tolerances, recording accuracy, per-batch forward time and the number of
# function evaluations (NFE) for each tolerance.

# map_location keeps the load working on CPU-only machines even when the
# checkpoint was saved from a CUDA run.
checkpoint = torch.load('ODEnet_mnist.pth', map_location=device)

batch_size = test_batch_size = 1000
data_aug = False

# The data loaders do not depend on the tolerance, so build them once instead
# of once per loop iteration.
train_loader, test_loader, train_eval_loader = get_mnist_loaders(data_aug, batch_size, test_batch_size)

# CUDA kernels run asynchronously; timings are only meaningful if we wait for
# the device before reading the clock.
sync_cuda = device.type == 'cuda'

all_acc = []    # accuracy per tolerance
all_times = []  # list of per-batch forward times (seconds) per tolerance
nfes = []       # accumulated NFE over the whole test set per tolerance
tols = [1e-4, 1e-3, 1e-2, 1e-1, 1]

for tol in tols:
    # Rebuild the model around a fresh ODE block so the tolerance changes and
    # the NFE counter starts at zero; the weights come from the checkpoint.
    feature_layers = [ODEBlock(ODEfunc(64), tol=tol)]
    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    times = []
    total_correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = one_hot(np.array(y.numpy()), 10)
            target_class = np.argmax(y, axis=1)

            if sync_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            preds = model(x)
            if sync_cuda:
                torch.cuda.synchronize()
            times.append(time.perf_counter() - start)

            predicted_class = np.argmax(preds.cpu().detach().numpy(), axis=1)
            total_correct += np.sum(predicted_class == target_class)

    accuracy = total_correct / len(test_loader.dataset)
    # The counter is only ever incremented, so this is the total number of
    # function evaluations over all test batches for this tolerance.
    nfe = feature_layers[0].nfe

    all_acc.append(accuracy)
    nfes.append(nfe)
    all_times.append(times)
    print('Tol={0}, accuracy={1}, mean_time={2:0.2f}, nfe={3}'.format(tol, accuracy, np.mean(times), nfe))
Tol=0.0001, accuracy=0.9961, mean_time=4.60, nfe=320
Tol=0.001, accuracy=0.9961, mean_time=3.76, nfe=260
Tol=0.01, accuracy=0.996, mean_time=2.93, nfe=200
Tol=0.1, accuracy=0.9961, mean_time=2.13, nfe=140
Tol=1, accuracy=0.9956, mean_time=2.13, nfe=140