#!/usr/bin/env python
# coding: utf-8

# # Self-supervised learning in 3D images
#
# Use the Doersch-style method as described in [1]
#
# [1] M. Blendowski et al. "How to Learn from Unlabeled Volume Data:
# Self-supervised 3D Context Feature Learning." MICCAI. 2019.

# ## Setup notebook

# In[1]:


from typing import Callable, List, Optional, Tuple, Union

from glob import glob
import math
import os
import random
import sys

gpu_id = 0
os.environ["CUDA_VISIBLE_DEVICES"] = f'{gpu_id}'

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision

from selfsupervised3d import *


# Support in-notebook plotting

# In[2]:


get_ipython().run_line_magic('matplotlib', 'inline')


# Report versions

# In[3]:


print('numpy version: {}'.format(np.__version__))
from matplotlib import __version__ as mplver
print('matplotlib version: {}'.format(mplver))
print(f'pytorch version: {torch.__version__}')
print(f'torchvision version: {torchvision.__version__}')


# In[4]:


pv = sys.version_info
print('python version: {}.{}.{}'.format(pv.major, pv.minor, pv.micro))


# Reload packages whose content changes during package development

# In[5]:


get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')


# Check GPU(s)

# In[6]:


get_ipython().system('nvidia-smi')


# In[7]:


assert torch.cuda.is_available()
device = torch.device('cuda')
torch.backends.cudnn.benchmark = True


# Set seeds for better reproducibility. See [this note](https://pytorch.org/docs/stable/notes/faq.html#my-data-loader-workers-return-identical-random-numbers) before using multiprocessing.

# In[8]:


seed = 1336
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


# ## Setup training and validation data
#
# Get the location of the training (and validation) data

# In[9]:


train_dir = '/iacl/pg20/jacobr/ixi/subsets/hh/'
t1_dir = os.path.join(train_dir, 't1')
t2_dir = os.path.join(train_dir, 't2')


# In[10]:


t1_fns = glob(os.path.join(t1_dir, '*.nii*'))
t2_fns = glob(os.path.join(t2_dir, '*.nii*'))
assert len(t1_fns) == len(t2_fns) and len(t1_fns) != 0


# ## Look at example training dataset
#
# Look at an axial view of example T1-weighted (T1-w) and T2-weighted (T2-w) images.

# In[11]:


def imshow(x, ax, title, n_rot=3):
    ax.imshow(np.rot90(x, n_rot), aspect='equal', cmap='gray')
    ax.set_title(title, fontsize=22)
    ax.axis('off')


# In[12]:


j = 100
t1_ex, t2_ex = (nib.load(t1_fns[0]).get_fdata(dtype=np.float32),
                nib.load(t2_fns[0]).get_fdata(dtype=np.float32))
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 9))
imshow(t1_ex[..., j], ax1, 'T1', 1)
imshow(t2_ex[..., j], ax2, 'T2', 1)


# Sample a center/query patch pair and the relative-location target from the example
# T1-w volume (a standalone sketch of this sampling geometry follows the
# training-setup heading below).

# In[13]:


x = torch.from_numpy(t1_ex).unsqueeze(0)


# In[14]:


(ctr, qry), goal = doersch_patches(x, patch_size=0.5, patch_dim=96)


# In[15]:


print(goal.item())


# In[16]:


ctr = ctr.squeeze().cpu().detach().numpy()
qry = qry.squeeze().cpu().detach().numpy()


# In[17]:


j = 12
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 9))
imshow(ctr[..., j], ax1, 'CTR', 1)
imshow(qry[..., j], ax2, 'QRY', 1)


# ## Setup training
#
# Hyperparameters, optimizers, logging, etc.
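# Before configuring training, the sketch below recaps the geometry behind the patch
# pairs sampled above: pick a center patch, pick one of its six axis-aligned neighbors
# as the query patch, and ask the networks to classify which neighbor was chosen. This
# is illustrative only; `sample_relative_patch_pair` is a hypothetical stand-in, not
# the `doersch_patches` implementation from `selfsupervised3d`, which also handles
# fractional patch sizes, offset jitter, and boundary cases.

# In[ ]:


def sample_relative_patch_pair(vol, patch_dim=32):
    """Illustrative sketch: sample a center patch and one of its six
    face-adjacent neighbor patches from a (C, H, W, D) volume, returning
    the neighbor index as the classification target."""
    # six axis-aligned unit offsets: -x, +x, -y, +y, -z, +z
    offsets = np.array([[-1, 0, 0], [1, 0, 0],
                        [0, -1, 0], [0, 1, 0],
                        [0, 0, -1], [0, 0, 1]])
    _, h, w, d = vol.shape
    p = patch_dim
    # choose the center-patch corner far enough from the boundary that every
    # neighbor patch also fits inside the volume (requires each dim > 3 * p)
    ci = np.random.randint(p, h - 2 * p)
    cj = np.random.randint(p, w - 2 * p)
    ck = np.random.randint(p, d - 2 * p)
    goal = np.random.randint(len(offsets))         # class label in [0, 6)
    qi, qj, qk = (ci, cj, ck) + p * offsets[goal]  # corner of the neighbor patch
    ctr = vol[:, ci:ci + p, cj:cj + p, ck:ck + p]
    qry = vol[:, qi:qi + p, qj:qj + p, qk:qk + p]
    return (ctr, qry), torch.tensor(goal)


# On the example T1-w volume, this yields two (1, 32, 32, 32) patches and an integer
# target in [0, 6). `doersch_patches` returns the analogous pair and target consumed
# by the training loop below, though its exact label set and patch handling may differ.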
# In[18]:


data_dirs = [t1_dir]


# In[19]:


# system setup
load_model = False

# logging setup
log_rate = 10  # print losses every log_rate epochs
version = 'doersch_v1'  # naming scheme of model to load
save_rate = 100  # save models every save_rate epochs

# model, optimizer, loss, and training parameters
valid_split = 0.1
batch_size = 8
n_jobs = 8
n_epochs = 500
input_channels = len(data_dirs)
descriptor_size = 192
use_adam = True
opt_kwargs = dict(lr=1e-3, betas=(0.9, 0.99), weight_decay=1e-6) if use_adam else \
             dict(lr=5e-3, momentum=0.9)
use_scheduler = True
scheduler_kwargs = dict(step_size=100, gamma=0.5)


# In[20]:


def init_fn(worker_id):
    random.seed((torch.initial_seed() + worker_id) % (2**32))
    np.random.seed((torch.initial_seed() + worker_id) % (2**32))


# In[21]:


# setup training and validation dataloaders
dataset = DoerschDataset(data_dirs)
num_train = len(dataset)
indices = list(range(num_train))
split = int(valid_split * num_train)
valid_idx = np.random.choice(indices, size=split, replace=False)
train_idx = list(set(indices) - set(valid_idx))
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=batch_size,
                          worker_init_fn=init_fn, num_workers=n_jobs,
                          pin_memory=True, collate_fn=doersch_collate)
valid_loader = DataLoader(dataset, sampler=valid_sampler, batch_size=batch_size,
                          worker_init_fn=init_fn, num_workers=n_jobs,
                          pin_memory=True, collate_fn=doersch_collate)


# In[22]:


print(f'Number of training images: {num_train - split}')
print(f'Number of validation images: {split}')


# In[23]:


embedding_model = DoerschNet(input_channels=input_channels, descriptor_size=descriptor_size)
decoder_model = DoerschDecodeNet(descriptor_size=descriptor_size)


# In[24]:


def num_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# In[25]:


print(f'Number of trainable parameters in embedding model: {num_params(embedding_model)}')
print(f'Number of trainable parameters in decoder model: {num_params(decoder_model)}')


# In[26]:


if load_model:
    embedding_model.load_state_dict(torch.load(f'embedding_model_{version}.pth'))
    decoder_model.load_state_dict(torch.load(f'decoder_model_{version}.pth'))


# In[27]:


embedding_model.to(device)
decoder_model.to(device)
optim_cls = torch.optim.AdamW if use_adam else torch.optim.SGD
embedding_opt = optim_cls(embedding_model.parameters(), **opt_kwargs)
decoder_opt = optim_cls(decoder_model.parameters(), **opt_kwargs)
criterion = nn.CrossEntropyLoss()
if use_scheduler:
    embedding_scheduler = torch.optim.lr_scheduler.StepLR(embedding_opt, **scheduler_kwargs)
    decoder_scheduler = torch.optim.lr_scheduler.StepLR(decoder_opt, **scheduler_kwargs)


# ## Train model

# In[28]:


train_losses, valid_losses = [], []
n_batches = len(train_loader)


# In[29]:


for t in range(1, n_epochs + 1):
    # training
    t_losses = []
    embedding_model.train()
    decoder_model.train()
    for i, ((ctr, qry), goal) in enumerate(train_loader):
        ctr, qry, goal = ctr.to(device), qry.to(device), goal.to(device)
        embedding_opt.zero_grad()
        decoder_opt.zero_grad()
        ctr_f = embedding_model(ctr)
        qry_f = embedding_model(qry)
        out = decoder_model(ctr_f, qry_f)
        loss = criterion(out, goal)
        t_losses.append(loss.item())
        loss.backward()
        embedding_opt.step()
        decoder_opt.step()
    train_losses.append(t_losses)

    # validation
    v_losses = []
    embedding_model.eval()
    decoder_model.eval()
    with torch.no_grad():
        for i, ((ctr, qry), goal) in enumerate(valid_loader):
            ctr, qry, goal = ctr.to(device), qry.to(device), goal.to(device)
            ctr_f = embedding_model(ctr)
            qry_f = embedding_model(qry)
            out = decoder_model(ctr_f, qry_f)
            loss = criterion(out, goal)
            v_losses.append(loss.item())
    valid_losses.append(v_losses)

    # log, step scheduler, and save results from epoch
    if not np.all(np.isfinite(t_losses)):
        raise RuntimeError('NaN or Inf in training loss, cannot recover. Exiting.')
    if t % log_rate == 0:
        log = (f'Epoch: {t} - TL: {np.mean(t_losses):.2e}, VL: {np.mean(v_losses):.2e}')
        print(log)
    if use_scheduler:
        embedding_scheduler.step()
        decoder_scheduler.step()
    if t % save_rate == 0:
        torch.save(embedding_model.state_dict(), f'embedding_model_{version}_{t}.pth')
        torch.save(decoder_model.state_dict(), f'decoder_model_{version}_{t}.pth')


# In[30]:


save_model = True
if save_model:
    torch.save(embedding_model.state_dict(), f'embedding_model_{version}.pth')
    torch.save(decoder_model.state_dict(), f'decoder_model_{version}.pth')


# ## Analyze training

# In[31]:


def tidy_losses(train, valid):
    out = {'epoch': [], 'type': [], 'value': [], 'phase': []}
    for i, (tl, vl) in enumerate(zip(train, valid), 1):
        for tli in tl:
            out['epoch'].append(i)
            out['type'].append('loss')
            out['value'].append(tli)
            out['phase'].append('train')
        for vli in vl:
            out['epoch'].append(i)
            out['type'].append('loss')
            out['value'].append(vli)
            out['phase'].append('valid')
    return pd.DataFrame(out)


# In[32]:


losses = tidy_losses(train_losses, valid_losses)


# In[33]:


f, ax1 = plt.subplots(1, 1, figsize=(12, 8), sharey=True)
sns.lineplot(x='epoch', y='value', hue='phase', data=losses, ci='sd', ax=ax1, lw=3);
ax1.set_yscale('log');
ax1.set_title('Losses');


# In[34]:


save_losses = False
if save_losses:
    f.savefig(f'losses_{version}.pdf')
    losses.to_csv(f'losses_{version}.csv')
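# ## Use the pretrained embedding (sketch)
#
# The point of the pretext task is to obtain a reusable feature extractor for
# downstream work. The cell below is a minimal sketch, not part of the original
# training procedure: it reloads the saved embedding weights, freezes them, and
# computes a descriptor for a single patch. The `example_patch` tensor is only a
# placeholder; its spatial size must match the patches `DoerschNet` was trained on
# (i.e., those produced by `doersch_patches` / `DoerschDataset`).

# In[ ]:


feature_extractor = DoerschNet(input_channels=input_channels,
                               descriptor_size=descriptor_size)
feature_extractor.load_state_dict(torch.load(f'embedding_model_{version}.pth',
                                             map_location=device))
feature_extractor.to(device)
feature_extractor.eval()
for p in feature_extractor.parameters():
    p.requires_grad = False  # freeze the pretrained weights

with torch.no_grad():
    # placeholder input; substitute a real center patch of the trained size
    example_patch = torch.randn(1, input_channels, 32, 32, 32, device=device)
    descriptor = feature_extractor(example_patch)
print(f'descriptor shape: {tuple(descriptor.shape)}')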