Use the proposed heatmap method as described in [1]
[1] M. Blendowski et al. "How to Learn from Unlabeled Volume Data:
Self-supervised 3D Context Feature Learning." MICCAI. 2019.
from typing import Callable, List, Optional, Tuple, Union
from glob import glob
import math
import os
import random
import sys
# Restrict CUDA to a single device; this must be set before `import torch`
# below initializes the CUDA context.
gpu_id = 0
os.environ["CUDA_VISIBLE_DEVICES"] = f'{gpu_id}'
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision
from selfsupervised3d import *
Support in-notebook plotting
%matplotlib inline
Report versions
print('numpy version: {}'.format(np.__version__))
from matplotlib import __version__ as mplver
print('matplotlib version: {}'.format(mplver))
print(f'pytorch version: {torch.__version__}')
print(f'torchvision version: {torchvision.__version__}')
numpy version: 1.17.2 matplotlib version: 3.1.1 pytorch version: 1.5.0 torchvision version: 0.6.0a0+82fd1c8
pv = sys.version_info
print('python version: {}.{}.{}'.format(pv.major, pv.minor, pv.micro))
python version: 3.7.7
Automatically reload packages whose content changes on disk (useful during package development)
%load_ext autoreload
%autoreload 2
Check GPU(s)
!nvidia-smi
Fri May 1 14:09:36 2020 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 430.40 Driver Version: 430.40 CUDA Version: 10.1 | |-------------------------------+----------------------+----------------------+ | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | |===============================+======================+======================| | 0 Tesla M40 24GB Off | 00000000:02:00.0 Off | 0 | | N/A 39C P0 58W / 250W | 1401MiB / 22945MiB | 0% Default | +-------------------------------+----------------------+----------------------+ | 1 Tesla M40 24GB Off | 00000000:03:00.0 Off | 0 | | N/A 57C P0 147W / 250W | 18818MiB / 22945MiB | 100% Default | +-------------------------------+----------------------+----------------------+ +-----------------------------------------------------------------------------+ | Processes: GPU Memory | | GPU PID Type Process name Usage | |=============================================================================| | 0 28995 C ...r/miniconda3/envs/synthtorch/bin/python 1390MiB | | 1 10555 C python3 18807MiB | +-----------------------------------------------------------------------------+
# Fail fast if no GPU is visible, then enable cuDNN autotuning, which picks
# the fastest convolution algorithms for fixed input sizes.
assert torch.cuda.is_available()
device = torch.device('cuda')
torch.backends.cudnn.benchmark = True
Set seeds for better reproducibility. See the PyTorch reproducibility notes before using multiprocessing dataloader workers.
# Seed every RNG used in this notebook (python, numpy, torch CPU and GPU).
# NOTE(review): cudnn.benchmark=True above still permits nondeterminism, so
# runs are only approximately reproducible.
seed = 1336
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
Get the location of the training (and validation) data
# Site-specific data root containing paired t1/ and t2/ NIfTI directories.
train_dir = '/iacl/pg20/jacobr/ixi/subsets/hh/'
t1_dir = os.path.join(train_dir, 't1')
t2_dir = os.path.join(train_dir, 't2')
t1_fns = glob(os.path.join(t1_dir, '*.nii*'))
t2_fns = glob(os.path.join(t2_dir, '*.nii*'))
# Equal counts are taken as evidence of T1/T2 pairing -- assumes matching
# filenames across the two directories; TODO confirm.
assert len(t1_fns) == len(t2_fns) and len(t1_fns) != 0
Look at an axial view of the source T1-weighted (T1-w) and target T2-weighted (T2-w) images.
def imshow(x, ax, title, n_rot=3, **kwargs):
    """Show a 2D array on `ax` in grayscale, rotated 90 degrees `n_rot` times.

    Extra keyword arguments (e.g. vmin/vmax) are forwarded to `ax.imshow`.
    """
    rotated = np.rot90(x, n_rot)
    ax.imshow(rotated, aspect='equal', cmap='gray', **kwargs)
    ax.set_title(title, fontsize=22)
    ax.axis('off')
# Preview one axial slice of the first T1-w/T2-w volume pair.
j = 100  # axial slice index along the third dimension
# get_fdata() replaces get_data(), which is deprecated and removed in
# nibabel 5.0; it returns a float array with well-defined scaling.
t1_ex, t2_ex = nib.load(t1_fns[0]).get_fdata(), nib.load(t2_fns[0]).get_fdata()
fig,(ax1,ax2) = plt.subplots(1,2,figsize=(16,9))
imshow(t1_ex[...,j], ax1, 'T1', 1)
imshow(t2_ex[...,j], ax2, 'T2', 1)
# Sample one (center, query) patch pair plus its training targets -- the
# in-plane displacement and the offset heatmap -- via the Blendowski scheme.
x = torch.from_numpy(t1_ex).unsqueeze(0)  # prepend a dim -- presumably batch/channel; TODO confirm
(ctr, qry), (dp_goal, hm_goal) = blendowski_patches(x, min_off_inplane=0., max_off_inplane=0.7, throughplane_axis=1)
ctr = ctr.squeeze().cpu().detach().numpy()
qry = qry.squeeze().cpu().detach().numpy()
hm_goal = hm_goal.squeeze()
dx, dy = dp_goal  # in-plane offsets; printed example below shows values in roughly [-0.7, 0.7]
print(f'dx: {dx:0.3f}, dy: {dy:0.3f}')
dx: -0.427, dy: 0.501
print(ctr.shape, qry.shape)
(3, 42, 42) (3, 42, 42)
# Show the middle slice (index 1 of the 3-slice stack) of each patch and the
# target heatmap.
j = 12  # NOTE(review): unused below -- slices are indexed with 1 directly; remove or use
fig,(ax1,ax2,ax3) = plt.subplots(1,3,figsize=(16,9))
imshow(ctr[1,...], ax1, 'CTR', 0)
imshow(qry[1,...], ax2, 'QRY', 0)
imshow(hm_goal, ax3, 'HM', 0)
Hyperparameters, optimizers, logging, etc.
# Train on the T1-w volumes only.
data_dirs = [t1_dir]
# system setup
load_model = False  # resume from saved checkpoints if True
# logging setup
log_rate = 10 # print losses every log_rate epochs
version = 'blendowski_v1' # naming scheme of model to load
save_rate = 100 # save models every save_rate epochs
# model, optimizer, loss, and training parameters
valid_split = 0.1  # fraction of volumes held out for validation
batch_size = 8
n_jobs = 8  # dataloader worker processes
n_epochs = 500
stack_dim = 3  # adjacent slices stacked as input channels
input_channels = stack_dim * len(data_dirs)
descriptor_size = 128  # dimensionality of the learned patch descriptor
use_adam = True  # AdamW if True, else SGD with momentum (see optim_cls below)
opt_kwargs = dict(lr=1e-3, betas=(0.9,0.99), weight_decay=1e-6) if use_adam else \
             dict(lr=5e-3, momentum=0.9)
use_scheduler = True
scheduler_kwargs = dict(step_size=100, gamma=0.5)  # halve the LR every 100 epochs
def init_fn(worker_id):
    """Seed python and numpy RNGs in a dataloader worker process.

    The seed is derived from torch's base seed plus the worker id, so each
    worker is decorrelated while the whole run stays reproducible.
    """
    worker_seed = (torch.initial_seed() + worker_id) % (2**32)
    random.seed(worker_seed)
    np.random.seed(worker_seed)
# setup training and validation dataloaders
dataset = BlendowskiDataset(data_dirs, stack_dim=stack_dim, throughplane_axis=1)
# Random split at the dataset-index level (presumably one index per volume;
# TODO confirm against BlendowskiDataset).
num_train = len(dataset)
indices = list(range(num_train))
split = int(valid_split * num_train)
valid_idx = np.random.choice(indices, size=split, replace=False)
train_idx = list(set(indices) - set(valid_idx))
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
# init_fn re-seeds python/numpy in each worker so per-worker sampling is
# reproducible; blendowski_collate batches the nested patch/target tuples.
train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=batch_size,
                          worker_init_fn=init_fn, num_workers=n_jobs,
                          pin_memory=True, collate_fn=blendowski_collate)
valid_loader = DataLoader(dataset, sampler=valid_sampler, batch_size=batch_size,
                          worker_init_fn=init_fn, num_workers=n_jobs,
                          pin_memory=True, collate_fn=blendowski_collate)
print(f'Number of training images: {num_train-split}')
print(f'Number of validation images: {split}')
Number of training images: 121 Number of validation images: 13
# Siamese embedding network maps a slice stack to a descriptor; the decoder
# ("HeatNet") predicts the offset heatmap from a pair of descriptors.
embedding_model = D2DConvNet(input_channels=input_channels, descriptor_size=descriptor_size)
decoder_model = HeatNet(descriptor_size=descriptor_size)
def num_params(model):
    """Return the number of trainable (requires_grad) parameters in `model`."""
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total
print(f'Number of trainable parameters in embedding model: {num_params(embedding_model)}')
print(f'Number of trainable parameters in decoder model: {num_params(decoder_model)}')
Number of trainable parameters in embedding model: 629184 Number of trainable parameters in decoder model: 111993
# Optionally resume from the checkpoints saved under the current version tag.
if load_model:
    embedding_model.load_state_dict(torch.load(f'embedding_model_{version}.pth'))
    decoder_model.load_state_dict(torch.load(f'decoder_model_{version}.pth'))
embedding_model.to(device)
decoder_model.to(device)
# Optimizer class chosen by use_adam; opt_kwargs was built to match above.
optim_cls = torch.optim.AdamW if use_adam else torch.optim.SGD
embedding_opt = optim_cls(embedding_model.parameters(), **opt_kwargs)
decoder_opt = optim_cls(decoder_model.parameters(), **opt_kwargs)
criterion = nn.MSELoss()  # regression loss between predicted and target heatmaps
if use_scheduler:
    embedding_scheduler = torch.optim.lr_scheduler.StepLR(embedding_opt, **scheduler_kwargs)
    decoder_scheduler = torch.optim.lr_scheduler.StepLR(decoder_opt, **scheduler_kwargs)
train_losses, valid_losses = [], []
n_batches = len(train_loader)
# Curriculum on the in-plane offset range: over training the sampling window
# moves from [0.25, 0.30] toward [0.0, 0.7] (per the paper).
min_off_inplane = np.linspace(0.25, 0.0, n_epochs)
max_off_inplane = np.linspace(0.30, 0.7, n_epochs)
for t in range(1, n_epochs + 1):
    # training
    t_losses = []
    embedding_model.train()
    decoder_model.train()
    for i, ((ctr, qry), (_, goal)) in enumerate(train_loader):
        ctr, qry, goal = ctr.to(device), qry.to(device), goal.to(device)
        embedding_opt.zero_grad()
        decoder_opt.zero_grad()
        # Both patches go through the same (shared-weight) embedding network.
        ctr_f = embedding_model(ctr)
        qry_f = embedding_model(qry)
        out = decoder_model(ctr_f, qry_f)
        loss = criterion(out, goal)
        t_losses.append(loss.item())
        loss.backward()
        embedding_opt.step()
        decoder_opt.step()
    train_losses.append(t_losses)
    # validation
    v_losses = []
    embedding_model.eval()
    decoder_model.eval()
    with torch.no_grad():
        for i, ((ctr, qry), (_, goal)) in enumerate(valid_loader):
            ctr, qry, goal = ctr.to(device), qry.to(device), goal.to(device)
            ctr_f = embedding_model(ctr)
            qry_f = embedding_model(qry)
            out = decoder_model(ctr_f, qry_f)
            loss = criterion(out, goal)
            v_losses.append(loss.item())
    valid_losses.append(v_losses)
    # expand inplane offset range as per paper
    # NOTE(review): set *after* epoch t finishes, so epoch 1 runs with the
    # dataset's initial values and epoch t+1 uses schedule index t-1 --
    # confirm this one-epoch lag is intended.
    dataset.min_off_inplane = min_off_inplane[t-1]
    dataset.max_off_inplane = max_off_inplane[t-1]
    # log, step scheduler, and save results from epoch
    if not np.all(np.isfinite(t_losses)):
        raise RuntimeError('NaN or Inf in training loss, cannot recover. Exiting.')
    if t % log_rate == 0:
        log = (f'Epoch: {t} - TL: {np.mean(t_losses):.2e}, VL: {np.mean(v_losses):.2e}')
        print(log)
    if use_scheduler:
        embedding_scheduler.step()
        decoder_scheduler.step()
    if t % save_rate == 0:
        torch.save(embedding_model.state_dict(), f'embedding_model_{version}_{t}.pth')
        torch.save(decoder_model.state_dict(), f'decoder_model_{version}_{t}.pth')
Epoch: 10 - TL: 1.66e+00, VL: 1.81e+00 Epoch: 20 - TL: 1.69e+00, VL: 1.70e+00 Epoch: 30 - TL: 1.49e+00, VL: 1.73e+00 Epoch: 40 - TL: 1.35e+00, VL: 1.32e+00 Epoch: 50 - TL: 1.49e+00, VL: 1.57e+00 Epoch: 60 - TL: 1.31e+00, VL: 1.11e+00 Epoch: 70 - TL: 1.10e+00, VL: 1.53e+00 Epoch: 80 - TL: 9.98e-01, VL: 9.29e-01 Epoch: 90 - TL: 9.39e-01, VL: 6.90e-01 Epoch: 100 - TL: 6.67e-01, VL: 6.35e-01 Epoch: 110 - TL: 3.99e-01, VL: 2.24e-01 Epoch: 120 - TL: 4.65e-01, VL: 6.83e-01 Epoch: 130 - TL: 5.81e-01, VL: 1.07e+00 Epoch: 140 - TL: 5.59e-01, VL: 3.85e-01 Epoch: 150 - TL: 4.78e-01, VL: 9.67e-01 Epoch: 160 - TL: 6.60e-01, VL: 4.83e-01 Epoch: 170 - TL: 8.24e-01, VL: 9.62e-01 Epoch: 180 - TL: 5.35e-01, VL: 5.47e-01 Epoch: 190 - TL: 6.28e-01, VL: 4.52e-01 Epoch: 200 - TL: 5.28e-01, VL: 4.18e-01 Epoch: 210 - TL: 4.18e-01, VL: 3.12e-01 Epoch: 220 - TL: 4.54e-01, VL: 4.38e-01 Epoch: 230 - TL: 4.47e-01, VL: 5.75e-01 Epoch: 240 - TL: 4.17e-01, VL: 3.26e-01 Epoch: 250 - TL: 4.96e-01, VL: 4.66e-01 Epoch: 260 - TL: 4.82e-01, VL: 5.64e-01 Epoch: 270 - TL: 4.84e-01, VL: 3.90e-01 Epoch: 280 - TL: 4.43e-01, VL: 4.28e-01 Epoch: 290 - TL: 5.33e-01, VL: 3.92e-01 Epoch: 300 - TL: 4.07e-01, VL: 4.18e-01 Epoch: 310 - TL: 4.79e-01, VL: 7.54e-01 Epoch: 320 - TL: 4.52e-01, VL: 2.68e-01 Epoch: 330 - TL: 4.03e-01, VL: 5.31e-01 Epoch: 340 - TL: 4.07e-01, VL: 4.47e-01 Epoch: 350 - TL: 4.04e-01, VL: 3.60e-01 Epoch: 360 - TL: 4.05e-01, VL: 4.04e-01 Epoch: 370 - TL: 5.12e-01, VL: 4.04e-01 Epoch: 380 - TL: 4.13e-01, VL: 5.28e-01 Epoch: 390 - TL: 4.43e-01, VL: 3.59e-01 Epoch: 400 - TL: 3.66e-01, VL: 6.18e-01 Epoch: 410 - TL: 3.58e-01, VL: 5.61e-01 Epoch: 420 - TL: 5.20e-01, VL: 4.96e-01 Epoch: 430 - TL: 3.63e-01, VL: 4.08e-01 Epoch: 440 - TL: 3.77e-01, VL: 4.85e-01 Epoch: 450 - TL: 3.66e-01, VL: 3.51e-01 Epoch: 460 - TL: 3.56e-01, VL: 2.35e-01 Epoch: 470 - TL: 3.54e-01, VL: 4.04e-01 Epoch: 480 - TL: 3.63e-01, VL: 3.85e-01 Epoch: 490 - TL: 3.54e-01, VL: 3.52e-01 Epoch: 500 - TL: 4.04e-01, VL: 2.23e-01
# Persist the final weights under the version tag (no epoch suffix).
save_model = True
if save_model:
    torch.save(embedding_model.state_dict(), f'embedding_model_{version}.pth')
    torch.save(decoder_model.state_dict(), f'decoder_model_{version}.pth')
# Qualitative check on the last validation batch: center/query patches,
# predicted heatmap, target heatmap, and their absolute difference.
fig,((ax1,ax2,ax3),(ax4,ax5,ax6)) = plt.subplots(2,3,figsize=(16,9))
try:
    # tensors -> numpy; AttributeError means they were converted on a prior run
    ctr = ctr.squeeze().cpu().detach().numpy()
    qry = qry.squeeze().cpu().detach().numpy()
    out = out.squeeze().cpu().detach().numpy()
    goal = goal.squeeze().cpu().detach().numpy()
except AttributeError:
    pass
gm = goal.max()  # shared color scale for the OUT/HM/DIFF panels
imshow(ctr[0,1,...], ax1, 'CTR', 0)
imshow(qry[0,1,...], ax2, 'QRY', 0)
ax3.axis('off')
imshow(out[0], ax4, 'OUT', 0, vmin=0, vmax=gm)
imshow(goal[0], ax5, 'HM', 0, vmin=0, vmax=gm)
imshow(np.abs(out[0]-goal[0]), ax6, 'DIFF', 0, vmin=0, vmax=gm)
def tidy_losses(train, valid):
    """Flatten per-epoch train/valid batch losses into a long-form DataFrame.

    Each row is one batch loss with columns: epoch (1-based), type ('loss'),
    value, and phase ('train' or 'valid'); per epoch, train rows precede
    valid rows.
    """
    records = []
    for epoch, (t_batch, v_batch) in enumerate(zip(train, valid), start=1):
        records.extend((epoch, 'loss', value, 'train') for value in t_batch)
        records.extend((epoch, 'loss', value, 'valid') for value in v_batch)
    return pd.DataFrame(records, columns=['epoch', 'type', 'value', 'phase'])
# Plot mean train/valid loss per epoch with a standard-deviation band
# computed across the batches of each epoch.
losses = tidy_losses(train_losses, valid_losses)
f, ax1 = plt.subplots(1,1,figsize=(12, 8),sharey=True)
# NOTE(review): ci='sd' is deprecated in seaborn >= 0.12 (use errorbar='sd');
# kept as-is for the seaborn version this notebook was run with.
sns.lineplot(x='epoch',y='value',hue='phase',data=losses,ci='sd',ax=ax1,lw=3);
ax1.set_yscale('log');
ax1.set_title('Losses');
# Optionally persist the figure and the raw losses under the version tag.
save_losses = False
if save_losses:
    f.savefig(f'losses_{version}.pdf')
    losses.to_csv(f'losses_{version}.csv')