NiftyNet is "an open source convolutional neural networks platform for medical image analysis and image-guided therapy" built on top of TensorFlow. Due to its available implementations of successful architectures, patch-based sampling and straightforward configuration, it has become a popular choice to get started with deep learning in medical imaging.
PyTorch is "an open source deep learning platform that provides a seamless path from research prototyping to production deployment". It is low-level enough to offer a lot of control over what is going on under the hood during training, and its dynamic computational graph allows for easy debugging. Being a generic deep learning framework, it is not tailored to the needs of the medical imaging field, although its popularity in this field is increasing rapidly.
One can extend a NiftyNet application, but it is not straightforward without being familiar with the framework and fluent in TensorFlow 1.X. Therefore it can be convenient to implement applications in PyTorch using NiftyNet models and functionalities. In particular, combining both frameworks allows for fast architecture experimentation and transfer learning.
So why not use both? In this tutorial we will port the parameters of a model trained on NiftyNet to a PyTorch model and compare the results of running an inference using both frameworks.
Although NiftyNet supports different applications, it is mostly used for medical image segmentation.
*Image segmentation using deep learning* were the 5 most common words in all full paper titles from both MICCAI 2018 and MIDL 2019, which shows the interest of the medical imaging community in the topic.
"Image segmentation using deep learning", guess this is the hottest topic in MIDL #MIDL2019 @midl_conference pic.twitter.com/64smdMjnxY
— Hua Ma (@forever_pippo) July 8, 2019
Just like #miccai2018! pic.twitter.com/3ZTHxj9iPT
— Julia Schnabel (@ja_schnabel) July 8, 2019
HighRes3DNet is a residual convolutional neural network architecture designed to have a large receptive field and preserve a high resolution using a relatively small number of parameters. It was presented in 2017 by Li et al. at IPMI: On the Compactness, Efficiency, and Representation of 3D Convolutional Networks: Brain Parcellation as a Pretext Task.
The authors used NiftyNet to implement and train a model based on this architecture to perform brain parcellation using $T_1$-weighted MR images from the ADNI dataset. They achieved competitive segmentation performance compared with state-of-the-art architectures such as DeepMedic or U-Net.
This figure from the paper shows a parcellation produced by HighRes3DNet:
The code for the architecture is in the NiftyNet GitHub repository. The authors have uploaded the parameters and configuration file to the Model Zoo.
After reading the paper and the code, it is relatively straightforward to implement the same architecture using PyTorch.
All the code is hosted in a GitHub repository: fepegar/miccai-educational-challenge-2019.
The latest release can also be found on the Zenodo repository under this DOI: 10.5281/zenodo.3352316.
If you have a Google account, the best way to run this notebook seamlessly is using Google Colab. You will need to click on "Open in playground", at the top left:
You will also get a couple of warnings that you can safely ignore. The first one warns about this notebook not being authored by Google and the second one asks for confirmation to reset all runtimes. These are valid points, but will not affect this tutorial.
Please report any issues on GitHub and I will fix them. You can also drop me an email if you have any questions or comments.
To write this notebook I used Ubuntu 18.04 installed on an Alienware 13 R3 laptop, which includes a 6-GB GeForce GTX 1060 NVIDIA GPU. I am using CUDA 9.0.
Inference using PyTorch took 5725 MB of GPU memory. TensorFlow, in contrast, usually pre-allocates as much GPU memory as it can beforehand.
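If you want to prevent TensorFlow from pre-allocating all the GPU memory, for example when sharing one GPU between both frameworks, a common option in TensorFlow 1.x is to enable memory growth for the session. This is a generic sketch, not something this notebook or NiftyNet requires:

```python
import tensorflow as tf

# Ask TensorFlow 1.x to allocate GPU memory on demand instead of all at once
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
```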
To run this notebook locally, I recommend downloading the repository and creating a conda environment:
git clone https://github.com/fepegar/miccai-educational-challenge-2019.git
cd miccai-educational-challenge-2019
conda create -n mec2019 python=3.6 -y
conda activate mec2019 && conda install jupyterlab -y && jupyter lab
An already executed version of the notebook can be rendered using nbviewer.
If you run the notebook, you can use interactive plots to navigate through the volume slices by setting this variable to True. You might need to run the volume visualization cells individually after running the whole notebook. This feature is experimental and therefore disabled by default.
interactive_plots = False
Clone NiftyNet and download some custom Python modules needed for this notebook. This might take one or two minutes.
%%capture --no-stderr
# This might take about 30 seconds
!rm -rf NiftyNet && git clone https://github.com/NifTK/NiftyNet --depth 1
!cd NiftyNet && git checkout df0f86733357fdc92bbc191c8fec0dcf49aa5499 && cd ..
!pip install -r NiftyNet/requirements-gpu.txt
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/requirements.txt
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/tf2pt.py
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/utils.py
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/visualization.py
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/highresnet_mapping.py
!curl -O https://raw.githubusercontent.com/fepegar/highresnet/master/GIFNiftyNet.ctbl
!pip install -r requirements.txt
!pip install --upgrade numpy
!pip install ipywidgets
import sys
sys.path.insert(0, 'NiftyNet')
%matplotlib inline
%config InlineBackend.figure_format='retina'
import os
import tempfile
from pathlib import Path
from configparser import ConfigParser
import numpy as np
import pandas as pd
try:
    # Fancy rendering of Pandas tables
    import google.colab.data_table
    %load_ext google.colab.data_table
    print("We are on Google Colab")
except ModuleNotFoundError:
    print("We are not on Google Colab")
pd.set_option('display.max_colwidth', -1) # do not truncate strings when displaying data frames
pd.set_option('display.max_rows', None) # show all rows
import torch
from highresnet import HighRes3DNet
We are on Google Colab
%%capture
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.python.util import deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False
import tf2pt
import utils
import visualization
import highresnet_mapping
if interactive_plots:  # for Colab or Jupyter
    plot_volume = visualization.plot_volume_interactive
else:  # for HTML, GitHub or nbviewer
    plot_volume = visualization.plot_volume
from niftynet.io.image_reader import ImageReader
from niftynet.engine.sampler_grid_v2 import GridSampler
from niftynet.engine.windows_aggregator_grid import GridSamplesAggregator
from niftynet.layer.pad import PadLayer
from niftynet.layer.binary_masking import BinaryMaskingLayer
from niftynet.layer.histogram_normalisation import HistogramNormalisationLayer
from niftynet.layer.mean_variance_normalisation import MeanVarNormalisationLayer
We can use NiftyNet's `net_download` to get all we need from the Model Zoo entry corresponding to brain parcellation using HighRes3DNet:
%%capture
%run NiftyNet/net_download.py highres3dnet_brain_parcellation_model_zoo
niftynet_dir = Path('~/niftynet').expanduser()
utils.list_files(niftynet_dir)
niftynet/
    data/
        OASIS/
            license
            OAS1_0145_MR2_mpr_n4_anon_sbj_111.nii.gz
    models/
        highres3dnet_brain_parcellation/
            inference_niftynet_log
            databrain_std_hist_models_otsu.txt
            settings_inference.txt
            Modality0.csv
            parcellation_output/
                window_seg_OAS1_0145_MR2_mpr_n4_anon_sbj_111__niftynet_out.nii.gz
                inferred.csv
            logs/
            models/
                model.ckpt-33000.meta
                model.ckpt-33000.data-00000-of-00001
                model.ckpt-33000.index
    extensions/
        __init__.py
        highres3dnet_brain_parcellation/
            __init__.py
            highres3dnet_config_eval.ini
        network/
            __init__.py
There are three directories under ~/niftynet:

- `extensions` is a Python package that contains the [configuration file](https://niftynet.readthedocs.io/en/dev/config_spec.html).
- `models` contains the landmarks for histogram standardization (an MRI preprocessing step) and the network parameters.
- `data` contains an OASIS MRI sample that can be used to test the model.

Here are the paths to the downloaded files and to the files that will be generated by the notebook. I use `nn` for NiftyNet, `tf` for TensorFlow and `pt` for PyTorch.
models_dir = niftynet_dir / 'models'
zoo_entry = 'highres3dnet_brain_parcellation'
input_checkpoint_tf_name = 'model.ckpt-33000'
input_checkpoint_tf_path = models_dir / zoo_entry / 'models' / input_checkpoint_tf_name
data_dir = niftynet_dir / 'data' / 'OASIS'
config_path = niftynet_dir / 'extensions' / zoo_entry / 'highres3dnet_config_eval.ini'
histogram_landmarks_path = models_dir / zoo_entry / 'databrain_std_hist_models_otsu.txt'
tempdir = Path(tempfile.gettempdir()) / 'miccai_niftynet_pytorch'
tempdir.mkdir(exist_ok=True)
output_csv_tf_path = tempdir / 'variables_tf.csv'
output_state_dict_tf_path = tempdir / 'state_dict_tf.pth'
output_state_dict_pt_path = tempdir / 'state_dict_pt.pth'
prediction_pt_dir = tempdir / 'prediction'
prediction_pt_dir.mkdir(exist_ok=True)
color_table_path = 'GIFNiftyNet.ctbl'
Note that the path to the checkpoint is not the path of an existing file, but the prefix shared by the three checkpoint files.
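For instance, we can list the files that share that prefix (a quick check using the paths defined above; the expected suffixes are `.data-00000-of-00001`, `.index` and `.meta`):

```python
# The checkpoint "path" is a prefix shared by three files
checkpoint_files = sorted(
    path.name for path in input_checkpoint_tf_path.parent.glob(f'{input_checkpoint_tf_name}*')
)
print(checkpoint_files)
```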
There are two modules in the repository that are relevant for this section: `tf2pt` contains generic functions that can be used to transform any TensorFlow model to PyTorch, and `highresnet_mapping` contains custom functions that are specific to HighRes3DNet.

Let's see what variables are stored in the checkpoint. For clarity, variables whose names contain `biased` or `ExponentialMovingAverage` are filtered out by `highresnet_mapping.is_not_valid()`; I have experimented with them and the results turned out to be different from the ones generated by NiftyNet.

We will store the variable names in a data frame to list them in this notebook, and the values in a Python dictionary to retrieve them later. I figured out the code in `tf2pt.checkpoint_tf_to_state_dict_tf()` by reading the corresponding TensorFlow docs and Stack Overflow answers.
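Under the hood, listing and loading variables from a TensorFlow 1.x checkpoint can be done with a couple of `tf.train` utilities. This is only a minimal sketch of the idea, not the actual implementation of `tf2pt.checkpoint_tf_to_state_dict_tf()`:

```python
import tensorflow as tf

# List the (name, shape) pairs stored in the checkpoint files
variables = tf.train.list_variables(str(input_checkpoint_tf_path))
for name, shape in variables[:5]:
    print(name, shape)

# Load the values of one variable as a NumPy array
first_name = variables[0][0]
array = tf.train.load_variable(str(input_checkpoint_tf_path), first_name)
print(first_name, array.shape)
```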
# %%capture
tf2pt.checkpoint_tf_to_state_dict_tf(
input_checkpoint_tf_path=input_checkpoint_tf_path,
output_csv_tf_path=output_csv_tf_path,
output_state_dict_tf_path=output_state_dict_tf_path,
filter_out_function=highresnet_mapping.is_not_valid,
replace_string='HighRes3DNet/',
)
data_frame_tf = pd.read_csv(output_csv_tf_path)
state_dict_tf = torch.load(output_state_dict_tf_path)
W0824 12:06:24.860505 140162299176832 deprecation_wrapper.py:119] From /content/tf2pt.py:106: The name tf.reset_default_graph is deprecated. Please use tf.compat.v1.reset_default_graph instead. W0824 12:06:24.872821 140162299176832 deprecation_wrapper.py:119] From /content/tf2pt.py:114: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead. W0824 12:06:25.661258 140162299176832 deprecation_wrapper.py:119] From /content/tf2pt.py:122: The name tf.train.Saver is deprecated. Please use tf.compat.v1.train.Saver instead. W0824 12:06:25.749863 140162299176832 deprecation_wrapper.py:119] From /content/tf2pt.py:124: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead. I0824 12:06:29.631742 140162299176832 saver.py:1280] Restoring parameters from /root/niftynet/models/highres3dnet_brain_parcellation/models/model.ckpt-33000
data_frame_tf
Unnamed: 0 | name | shape | |
---|---|---|---|
0 | 0 | conv_0_bn_relu/bn_/beta | 16 |
1 | 1 | conv_0_bn_relu/bn_/gamma | 16 |
2 | 2 | conv_0_bn_relu/bn_/moving_mean | 16 |
3 | 3 | conv_0_bn_relu/bn_/moving_variance | 16 |
4 | 4 | conv_0_bn_relu/conv_/w | 3, 3, 3, 1, 16 |
5 | 5 | conv_1_bn_relu/bn_/beta | 80 |
6 | 6 | conv_1_bn_relu/bn_/gamma | 80 |
7 | 7 | conv_1_bn_relu/bn_/moving_mean | 80 |
8 | 8 | conv_1_bn_relu/bn_/moving_variance | 80 |
9 | 9 | conv_1_bn_relu/conv_/w | 1, 1, 1, 64, 80 |
10 | 10 | conv_2_bn/bn_/beta | 160 |
11 | 11 | conv_2_bn/bn_/gamma | 160 |
12 | 12 | conv_2_bn/bn_/moving_mean | 160 |
13 | 13 | conv_2_bn/bn_/moving_variance | 160 |
14 | 14 | conv_2_bn/conv_/w | 1, 1, 1, 80, 160 |
15 | 15 | res_1_0/bn_0/beta | 16 |
16 | 16 | res_1_0/bn_0/gamma | 16 |
17 | 17 | res_1_0/bn_0/moving_mean | 16 |
18 | 18 | res_1_0/bn_0/moving_variance | 16 |
19 | 19 | res_1_0/bn_1/beta | 16 |
20 | 20 | res_1_0/bn_1/gamma | 16 |
21 | 21 | res_1_0/bn_1/moving_mean | 16 |
22 | 22 | res_1_0/bn_1/moving_variance | 16 |
23 | 23 | res_1_0/conv_0/w | 3, 3, 3, 16, 16 |
24 | 24 | res_1_0/conv_1/w | 3, 3, 3, 16, 16 |
25 | 25 | res_1_1/bn_0/beta | 16 |
26 | 26 | res_1_1/bn_0/gamma | 16 |
27 | 27 | res_1_1/bn_0/moving_mean | 16 |
28 | 28 | res_1_1/bn_0/moving_variance | 16 |
29 | 29 | res_1_1/bn_1/beta | 16 |
... | ... | ... | ... |
75 | 75 | res_3_0/bn_0/beta | 32 |
76 | 76 | res_3_0/bn_0/gamma | 32 |
77 | 77 | res_3_0/bn_0/moving_mean | 32 |
78 | 78 | res_3_0/bn_0/moving_variance | 32 |
79 | 79 | res_3_0/bn_1/beta | 64 |
80 | 80 | res_3_0/bn_1/gamma | 64 |
81 | 81 | res_3_0/bn_1/moving_mean | 64 |
82 | 82 | res_3_0/bn_1/moving_variance | 64 |
83 | 83 | res_3_0/conv_0/w | 3, 3, 3, 32, 64 |
84 | 84 | res_3_0/conv_1/w | 3, 3, 3, 64, 64 |
85 | 85 | res_3_1/bn_0/beta | 64 |
86 | 86 | res_3_1/bn_0/gamma | 64 |
87 | 87 | res_3_1/bn_0/moving_mean | 64 |
88 | 88 | res_3_1/bn_0/moving_variance | 64 |
89 | 89 | res_3_1/bn_1/beta | 64 |
90 | 90 | res_3_1/bn_1/gamma | 64 |
91 | 91 | res_3_1/bn_1/moving_mean | 64 |
92 | 92 | res_3_1/bn_1/moving_variance | 64 |
93 | 93 | res_3_1/conv_0/w | 3, 3, 3, 64, 64 |
94 | 94 | res_3_1/conv_1/w | 3, 3, 3, 64, 64 |
95 | 95 | res_3_2/bn_0/beta | 64 |
96 | 96 | res_3_2/bn_0/gamma | 64 |
97 | 97 | res_3_2/bn_0/moving_mean | 64 |
98 | 98 | res_3_2/bn_0/moving_variance | 64 |
99 | 99 | res_3_2/bn_1/beta | 64 |
100 | 100 | res_3_2/bn_1/gamma | 64 |
101 | 101 | res_3_2/bn_1/moving_mean | 64 |
102 | 102 | res_3_2/bn_1/moving_variance | 64 |
103 | 103 | res_3_2/conv_0/w | 3, 3, 3, 64, 64 |
104 | 104 | res_3_2/conv_1/w | 3, 3, 3, 64, 64 |
105 rows × 3 columns
The weight parameters associated with each convolutional layer, denoted with `conv_/w`, are stored with a shape representing the three spatial dimensions, the input channels and the output channels: $(Depth, Height, Width, Channels_{in}, Channels_{out})$. Calling the spatial dimensions depth, height and width does not make a lot of sense when dealing with 3D medical images, but we will keep this computer vision terminology, as it is consistent with the documentation of both PyTorch and TensorFlow.
The layer names and parameter shapes are overall consistent with the figure in the HighRes3DNet paper, but there is an additional $1 \times 1 \times 1$ convolutional layer with 80 output channels, which is also present in the code. It seems to correspond to the variant with dropout, which achieved the highest performance in the paper, so our implementation of the architecture should include this layer as well.
There are three blocks with increasing kernel dilation composed of three residual blocks each, which have two convolutional layers inside. That's $3 \times 3 \times 2 = 18$ layers. The other three convolutional layers are the initial convolution before the first residual block, a convolution before dropout and a convolution to expand the channels to the number of output classes.
Apparently, all the convolutional layers have an associated batch normalization layer, which differs from the figure in the paper. That makes 21 convolutional layers and 21 batch normalization layers whose parameters must be transferred.
Each batch normalization layer contains 4 parameter groups: moving mean $\mathrm{E}[x]$, variance $\mathrm{Var}[x]$ and the affine transformation parameters $\gamma$ (scale or weight) and $\beta$ (shift or bias):
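For reference, this is the standard batch normalization transform applied at inference time, where $\epsilon$ is a small constant added for numerical stability:

$$y = \gamma \, \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} + \beta$$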
Therefore the total number of parameter groups is $21 + 21 \times 4 = 105$. The convolutional layers don't use a bias parameter, as it is not necessary when using batch norm.
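We can sanity-check these numbers on the data frame, since only the convolutional kernels have names ending in `/w`:

```python
# 21 convolutional kernels and 21 * 4 = 84 batch norm parameter groups
num_conv_kernels = data_frame_tf['name'].str.endswith('/w').sum()
num_batch_norm_groups = len(data_frame_tf) - num_conv_kernels
print(num_conv_kernels, num_batch_norm_groups)  # expected: 21 84
```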
To match the model in the [paper](https://arxiv.org/abs/1707.01992), we set the number of classes to 160 and enable the flag to add the dropout layer.
num_input_modalities = 1
num_classes = 160
model = HighRes3DNet(num_input_modalities, num_classes, add_dropout_layer=True)
Let's see what the variable names created by PyTorch are:
state_dict_pt = model.state_dict()
rows = []
for name, parameters in state_dict_pt.items():
    if 'num_batches_tracked' in name:  # filter out for clarity
        continue
    shape = ', '.join(str(n) for n in parameters.shape)
    row = {'name': name, 'shape': shape}
    rows.append(row)
df_pt = pd.DataFrame.from_dict(rows)
df_pt
name | shape | |
---|---|---|
0 | block.0.convolutional_block.1.weight | 16, 1, 3, 3, 3 |
1 | block.0.convolutional_block.2.weight | 16 |
2 | block.0.convolutional_block.2.bias | 16 |
3 | block.0.convolutional_block.2.running_mean | 16 |
4 | block.0.convolutional_block.2.running_var | 16 |
5 | block.1.dilation_block.0.residual_block.0.conv... | 16 |
6 | block.1.dilation_block.0.residual_block.0.conv... | 16 |
7 | block.1.dilation_block.0.residual_block.0.conv... | 16 |
8 | block.1.dilation_block.0.residual_block.0.conv... | 16 |
9 | block.1.dilation_block.0.residual_block.0.conv... | 16, 16, 3, 3, 3 |
10 | block.1.dilation_block.0.residual_block.1.conv... | 16 |
11 | block.1.dilation_block.0.residual_block.1.conv... | 16 |
12 | block.1.dilation_block.0.residual_block.1.conv... | 16 |
13 | block.1.dilation_block.0.residual_block.1.conv... | 16 |
14 | block.1.dilation_block.0.residual_block.1.conv... | 16, 16, 3, 3, 3 |
15 | block.1.dilation_block.1.residual_block.0.conv... | 16 |
16 | block.1.dilation_block.1.residual_block.0.conv... | 16 |
17 | block.1.dilation_block.1.residual_block.0.conv... | 16 |
18 | block.1.dilation_block.1.residual_block.0.conv... | 16 |
19 | block.1.dilation_block.1.residual_block.0.conv... | 16, 16, 3, 3, 3 |
20 | block.1.dilation_block.1.residual_block.1.conv... | 16 |
21 | block.1.dilation_block.1.residual_block.1.conv... | 16 |
22 | block.1.dilation_block.1.residual_block.1.conv... | 16 |
23 | block.1.dilation_block.1.residual_block.1.conv... | 16 |
24 | block.1.dilation_block.1.residual_block.1.conv... | 16, 16, 3, 3, 3 |
25 | block.1.dilation_block.2.residual_block.0.conv... | 16 |
26 | block.1.dilation_block.2.residual_block.0.conv... | 16 |
27 | block.1.dilation_block.2.residual_block.0.conv... | 16 |
28 | block.1.dilation_block.2.residual_block.0.conv... | 16 |
29 | block.1.dilation_block.2.residual_block.0.conv... | 16, 16, 3, 3, 3 |
... | ... | ... |
75 | block.3.dilation_block.1.residual_block.0.conv... | 64 |
76 | block.3.dilation_block.1.residual_block.0.conv... | 64 |
77 | block.3.dilation_block.1.residual_block.0.conv... | 64 |
78 | block.3.dilation_block.1.residual_block.0.conv... | 64 |
79 | block.3.dilation_block.1.residual_block.0.conv... | 64, 64, 3, 3, 3 |
80 | block.3.dilation_block.1.residual_block.1.conv... | 64 |
81 | block.3.dilation_block.1.residual_block.1.conv... | 64 |
82 | block.3.dilation_block.1.residual_block.1.conv... | 64 |
83 | block.3.dilation_block.1.residual_block.1.conv... | 64 |
84 | block.3.dilation_block.1.residual_block.1.conv... | 64, 64, 3, 3, 3 |
85 | block.3.dilation_block.2.residual_block.0.conv... | 64 |
86 | block.3.dilation_block.2.residual_block.0.conv... | 64 |
87 | block.3.dilation_block.2.residual_block.0.conv... | 64 |
88 | block.3.dilation_block.2.residual_block.0.conv... | 64 |
89 | block.3.dilation_block.2.residual_block.0.conv... | 64, 64, 3, 3, 3 |
90 | block.3.dilation_block.2.residual_block.1.conv... | 64 |
91 | block.3.dilation_block.2.residual_block.1.conv... | 64 |
92 | block.3.dilation_block.2.residual_block.1.conv... | 64 |
93 | block.3.dilation_block.2.residual_block.1.conv... | 64 |
94 | block.3.dilation_block.2.residual_block.1.conv... | 64, 64, 3, 3, 3 |
95 | block.4.convolutional_block.0.weight | 80, 64, 1, 1, 1 |
96 | block.4.convolutional_block.1.weight | 80 |
97 | block.4.convolutional_block.1.bias | 80 |
98 | block.4.convolutional_block.1.running_mean | 80 |
99 | block.4.convolutional_block.1.running_var | 80 |
100 | block.6.convolutional_block.0.weight | 160, 80, 1, 1, 1 |
101 | block.6.convolutional_block.1.weight | 160 |
102 | block.6.convolutional_block.1.bias | 160 |
103 | block.6.convolutional_block.1.running_mean | 160 |
104 | block.6.convolutional_block.1.running_var | 160 |
105 rows × 2 columns
We can see that `moving_mean` and `moving_variance` are called `running_mean` and `running_var` in PyTorch. Also, $\gamma$ and $\beta$ are called `weight` and `bias`.
The convolutional kernels have a different arrangement: $(Channels_{out}, Channels_{in}, Depth, Height, Width)$.
The names and shapes look consistent between both implementations and there are 105 lines in both lists, therefore we should be able to create a mapping between the TensorFlow and PyTorch variables. The function `tf2pt.tf2pt()` receives a TensorFlow-like variable and returns the corresponding PyTorch-like variable.
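To make the mapping concrete, here is an illustrative sketch of the two transformations involved: renaming the batch norm parameter suffixes and permuting the convolutional kernel axes. It is not the actual implementation in `tf2pt` or `highresnet_mapping`, and it assumes the TensorFlow values are available as NumPy arrays:

```python
import numpy as np
import torch

# Batch norm parameters keep their values; only the name suffix changes
BATCH_NORM_SUFFIXES = {
    'beta': 'bias',
    'gamma': 'weight',
    'moving_mean': 'running_mean',
    'moving_variance': 'running_var',
}

def convert_value(name_tf: str, array_tf: np.ndarray) -> torch.Tensor:
    """Illustrative conversion of a single TensorFlow variable value to PyTorch conventions."""
    tensor = torch.from_numpy(array_tf)
    if name_tf.endswith('/w'):  # convolutional kernel
        # (Depth, Height, Width, C_in, C_out) -> (C_out, C_in, Depth, Height, Width)
        tensor = tensor.permute(4, 3, 0, 1, 2).contiguous()
    return tensor
```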
for name_tf, tensor_tf in state_dict_tf.items():
    shape_tf = tuple(tensor_tf.shape)
    print(f'{str(shape_tf):18}', name_tf)

    # Convert TensorFlow name to PyTorch name
    mapping_function = highresnet_mapping.tf2pt_name
    name_pt, tensor_pt = tf2pt.tf2pt(name_tf, tensor_tf, mapping_function)
    shape_pt = tuple(state_dict_pt[name_pt].shape)
    print(f'{str(shape_pt):18}', name_pt)

    # Sanity check
    if sum(shape_tf) != sum(shape_pt):
        raise ValueError

    # Add weights to PyTorch state dictionary
    state_dict_pt[name_pt] = tensor_pt
    print()
torch.save(state_dict_pt, output_state_dict_pt_path)
print('State dictionary saved to', output_state_dict_pt_path)
(16,) conv_0_bn_relu/bn_/beta (16,) block.0.convolutional_block.2.bias (16,) conv_0_bn_relu/bn_/gamma (16,) block.0.convolutional_block.2.weight (16,) conv_0_bn_relu/bn_/moving_mean (16,) block.0.convolutional_block.2.running_mean (16,) conv_0_bn_relu/bn_/moving_variance (16,) block.0.convolutional_block.2.running_var (3, 3, 3, 1, 16) conv_0_bn_relu/conv_/w (16, 1, 3, 3, 3) block.0.convolutional_block.1.weight (80,) conv_1_bn_relu/bn_/beta (80,) block.4.convolutional_block.1.bias (80,) conv_1_bn_relu/bn_/gamma (80,) block.4.convolutional_block.1.weight (80,) conv_1_bn_relu/bn_/moving_mean (80,) block.4.convolutional_block.1.running_mean (80,) conv_1_bn_relu/bn_/moving_variance (80,) block.4.convolutional_block.1.running_var (1, 1, 1, 64, 80) conv_1_bn_relu/conv_/w (80, 64, 1, 1, 1) block.4.convolutional_block.0.weight (160,) conv_2_bn/bn_/beta (160,) block.6.convolutional_block.1.bias (160,) conv_2_bn/bn_/gamma (160,) block.6.convolutional_block.1.weight (160,) conv_2_bn/bn_/moving_mean (160,) block.6.convolutional_block.1.running_mean (160,) conv_2_bn/bn_/moving_variance (160,) block.6.convolutional_block.1.running_var (1, 1, 1, 80, 160) conv_2_bn/conv_/w (160, 80, 1, 1, 1) block.6.convolutional_block.0.weight (16,) res_1_0/bn_0/beta (16,) block.1.dilation_block.0.residual_block.0.convolutional_block.0.bias (16,) res_1_0/bn_0/gamma (16,) block.1.dilation_block.0.residual_block.0.convolutional_block.0.weight (16,) res_1_0/bn_0/moving_mean (16,) block.1.dilation_block.0.residual_block.0.convolutional_block.0.running_mean (16,) res_1_0/bn_0/moving_variance (16,) block.1.dilation_block.0.residual_block.0.convolutional_block.0.running_var (16,) res_1_0/bn_1/beta (16,) block.1.dilation_block.0.residual_block.1.convolutional_block.0.bias (16,) res_1_0/bn_1/gamma (16,) block.1.dilation_block.0.residual_block.1.convolutional_block.0.weight (16,) res_1_0/bn_1/moving_mean (16,) block.1.dilation_block.0.residual_block.1.convolutional_block.0.running_mean (16,) res_1_0/bn_1/moving_variance (16,) block.1.dilation_block.0.residual_block.1.convolutional_block.0.running_var (3, 3, 3, 16, 16) res_1_0/conv_0/w (16, 16, 3, 3, 3) block.1.dilation_block.0.residual_block.0.convolutional_block.3.weight (3, 3, 3, 16, 16) res_1_0/conv_1/w (16, 16, 3, 3, 3) block.1.dilation_block.0.residual_block.1.convolutional_block.3.weight (16,) res_1_1/bn_0/beta (16,) block.1.dilation_block.1.residual_block.0.convolutional_block.0.bias (16,) res_1_1/bn_0/gamma (16,) block.1.dilation_block.1.residual_block.0.convolutional_block.0.weight (16,) res_1_1/bn_0/moving_mean (16,) block.1.dilation_block.1.residual_block.0.convolutional_block.0.running_mean (16,) res_1_1/bn_0/moving_variance (16,) block.1.dilation_block.1.residual_block.0.convolutional_block.0.running_var (16,) res_1_1/bn_1/beta (16,) block.1.dilation_block.1.residual_block.1.convolutional_block.0.bias (16,) res_1_1/bn_1/gamma (16,) block.1.dilation_block.1.residual_block.1.convolutional_block.0.weight (16,) res_1_1/bn_1/moving_mean (16,) block.1.dilation_block.1.residual_block.1.convolutional_block.0.running_mean (16,) res_1_1/bn_1/moving_variance (16,) block.1.dilation_block.1.residual_block.1.convolutional_block.0.running_var (3, 3, 3, 16, 16) res_1_1/conv_0/w (16, 16, 3, 3, 3) block.1.dilation_block.1.residual_block.0.convolutional_block.3.weight (3, 3, 3, 16, 16) res_1_1/conv_1/w (16, 16, 3, 3, 3) block.1.dilation_block.1.residual_block.1.convolutional_block.3.weight (16,) res_1_2/bn_0/beta (16,) block.1.dilation_block.2.residual_block.0.convolutional_block.0.bias 
(16,) res_1_2/bn_0/gamma (16,) block.1.dilation_block.2.residual_block.0.convolutional_block.0.weight (16,) res_1_2/bn_0/moving_mean (16,) block.1.dilation_block.2.residual_block.0.convolutional_block.0.running_mean (16,) res_1_2/bn_0/moving_variance (16,) block.1.dilation_block.2.residual_block.0.convolutional_block.0.running_var (16,) res_1_2/bn_1/beta (16,) block.1.dilation_block.2.residual_block.1.convolutional_block.0.bias (16,) res_1_2/bn_1/gamma (16,) block.1.dilation_block.2.residual_block.1.convolutional_block.0.weight (16,) res_1_2/bn_1/moving_mean (16,) block.1.dilation_block.2.residual_block.1.convolutional_block.0.running_mean (16,) res_1_2/bn_1/moving_variance (16,) block.1.dilation_block.2.residual_block.1.convolutional_block.0.running_var (3, 3, 3, 16, 16) res_1_2/conv_0/w (16, 16, 3, 3, 3) block.1.dilation_block.2.residual_block.0.convolutional_block.3.weight (3, 3, 3, 16, 16) res_1_2/conv_1/w (16, 16, 3, 3, 3) block.1.dilation_block.2.residual_block.1.convolutional_block.3.weight (16,) res_2_0/bn_0/beta (16,) block.2.dilation_block.0.residual_block.0.convolutional_block.0.bias (16,) res_2_0/bn_0/gamma (16,) block.2.dilation_block.0.residual_block.0.convolutional_block.0.weight (16,) res_2_0/bn_0/moving_mean (16,) block.2.dilation_block.0.residual_block.0.convolutional_block.0.running_mean (16,) res_2_0/bn_0/moving_variance (16,) block.2.dilation_block.0.residual_block.0.convolutional_block.0.running_var (32,) res_2_0/bn_1/beta (32,) block.2.dilation_block.0.residual_block.1.convolutional_block.0.bias (32,) res_2_0/bn_1/gamma (32,) block.2.dilation_block.0.residual_block.1.convolutional_block.0.weight (32,) res_2_0/bn_1/moving_mean (32,) block.2.dilation_block.0.residual_block.1.convolutional_block.0.running_mean (32,) res_2_0/bn_1/moving_variance (32,) block.2.dilation_block.0.residual_block.1.convolutional_block.0.running_var (3, 3, 3, 16, 32) res_2_0/conv_0/w (32, 16, 3, 3, 3) block.2.dilation_block.0.residual_block.0.convolutional_block.3.weight (3, 3, 3, 32, 32) res_2_0/conv_1/w (32, 32, 3, 3, 3) block.2.dilation_block.0.residual_block.1.convolutional_block.3.weight (32,) res_2_1/bn_0/beta (32,) block.2.dilation_block.1.residual_block.0.convolutional_block.0.bias (32,) res_2_1/bn_0/gamma (32,) block.2.dilation_block.1.residual_block.0.convolutional_block.0.weight (32,) res_2_1/bn_0/moving_mean (32,) block.2.dilation_block.1.residual_block.0.convolutional_block.0.running_mean (32,) res_2_1/bn_0/moving_variance (32,) block.2.dilation_block.1.residual_block.0.convolutional_block.0.running_var (32,) res_2_1/bn_1/beta (32,) block.2.dilation_block.1.residual_block.1.convolutional_block.0.bias (32,) res_2_1/bn_1/gamma (32,) block.2.dilation_block.1.residual_block.1.convolutional_block.0.weight (32,) res_2_1/bn_1/moving_mean (32,) block.2.dilation_block.1.residual_block.1.convolutional_block.0.running_mean (32,) res_2_1/bn_1/moving_variance (32,) block.2.dilation_block.1.residual_block.1.convolutional_block.0.running_var (3, 3, 3, 32, 32) res_2_1/conv_0/w (32, 32, 3, 3, 3) block.2.dilation_block.1.residual_block.0.convolutional_block.3.weight (3, 3, 3, 32, 32) res_2_1/conv_1/w (32, 32, 3, 3, 3) block.2.dilation_block.1.residual_block.1.convolutional_block.3.weight (32,) res_2_2/bn_0/beta (32,) block.2.dilation_block.2.residual_block.0.convolutional_block.0.bias (32,) res_2_2/bn_0/gamma (32,) block.2.dilation_block.2.residual_block.0.convolutional_block.0.weight (32,) res_2_2/bn_0/moving_mean (32,) block.2.dilation_block.2.residual_block.0.convolutional_block.0.running_mean 
(32,) res_2_2/bn_0/moving_variance (32,) block.2.dilation_block.2.residual_block.0.convolutional_block.0.running_var (32,) res_2_2/bn_1/beta (32,) block.2.dilation_block.2.residual_block.1.convolutional_block.0.bias (32,) res_2_2/bn_1/gamma (32,) block.2.dilation_block.2.residual_block.1.convolutional_block.0.weight (32,) res_2_2/bn_1/moving_mean (32,) block.2.dilation_block.2.residual_block.1.convolutional_block.0.running_mean (32,) res_2_2/bn_1/moving_variance (32,) block.2.dilation_block.2.residual_block.1.convolutional_block.0.running_var (3, 3, 3, 32, 32) res_2_2/conv_0/w (32, 32, 3, 3, 3) block.2.dilation_block.2.residual_block.0.convolutional_block.3.weight (3, 3, 3, 32, 32) res_2_2/conv_1/w (32, 32, 3, 3, 3) block.2.dilation_block.2.residual_block.1.convolutional_block.3.weight (32,) res_3_0/bn_0/beta (32,) block.3.dilation_block.0.residual_block.0.convolutional_block.0.bias (32,) res_3_0/bn_0/gamma (32,) block.3.dilation_block.0.residual_block.0.convolutional_block.0.weight (32,) res_3_0/bn_0/moving_mean (32,) block.3.dilation_block.0.residual_block.0.convolutional_block.0.running_mean (32,) res_3_0/bn_0/moving_variance (32,) block.3.dilation_block.0.residual_block.0.convolutional_block.0.running_var (64,) res_3_0/bn_1/beta (64,) block.3.dilation_block.0.residual_block.1.convolutional_block.0.bias (64,) res_3_0/bn_1/gamma (64,) block.3.dilation_block.0.residual_block.1.convolutional_block.0.weight (64,) res_3_0/bn_1/moving_mean (64,) block.3.dilation_block.0.residual_block.1.convolutional_block.0.running_mean (64,) res_3_0/bn_1/moving_variance (64,) block.3.dilation_block.0.residual_block.1.convolutional_block.0.running_var (3, 3, 3, 32, 64) res_3_0/conv_0/w (64, 32, 3, 3, 3) block.3.dilation_block.0.residual_block.0.convolutional_block.3.weight (3, 3, 3, 64, 64) res_3_0/conv_1/w (64, 64, 3, 3, 3) block.3.dilation_block.0.residual_block.1.convolutional_block.3.weight (64,) res_3_1/bn_0/beta (64,) block.3.dilation_block.1.residual_block.0.convolutional_block.0.bias (64,) res_3_1/bn_0/gamma (64,) block.3.dilation_block.1.residual_block.0.convolutional_block.0.weight (64,) res_3_1/bn_0/moving_mean (64,) block.3.dilation_block.1.residual_block.0.convolutional_block.0.running_mean (64,) res_3_1/bn_0/moving_variance (64,) block.3.dilation_block.1.residual_block.0.convolutional_block.0.running_var (64,) res_3_1/bn_1/beta (64,) block.3.dilation_block.1.residual_block.1.convolutional_block.0.bias (64,) res_3_1/bn_1/gamma (64,) block.3.dilation_block.1.residual_block.1.convolutional_block.0.weight (64,) res_3_1/bn_1/moving_mean (64,) block.3.dilation_block.1.residual_block.1.convolutional_block.0.running_mean (64,) res_3_1/bn_1/moving_variance (64,) block.3.dilation_block.1.residual_block.1.convolutional_block.0.running_var (3, 3, 3, 64, 64) res_3_1/conv_0/w (64, 64, 3, 3, 3) block.3.dilation_block.1.residual_block.0.convolutional_block.3.weight (3, 3, 3, 64, 64) res_3_1/conv_1/w (64, 64, 3, 3, 3) block.3.dilation_block.1.residual_block.1.convolutional_block.3.weight (64,) res_3_2/bn_0/beta (64,) block.3.dilation_block.2.residual_block.0.convolutional_block.0.bias (64,) res_3_2/bn_0/gamma (64,) block.3.dilation_block.2.residual_block.0.convolutional_block.0.weight (64,) res_3_2/bn_0/moving_mean (64,) block.3.dilation_block.2.residual_block.0.convolutional_block.0.running_mean (64,) res_3_2/bn_0/moving_variance (64,) block.3.dilation_block.2.residual_block.0.convolutional_block.0.running_var (64,) res_3_2/bn_1/beta (64,) block.3.dilation_block.2.residual_block.1.convolutional_block.0.bias 
(64,) res_3_2/bn_1/gamma (64,) block.3.dilation_block.2.residual_block.1.convolutional_block.0.weight (64,) res_3_2/bn_1/moving_mean (64,) block.3.dilation_block.2.residual_block.1.convolutional_block.0.running_mean (64,) res_3_2/bn_1/moving_variance (64,) block.3.dilation_block.2.residual_block.1.convolutional_block.0.running_var (3, 3, 3, 64, 64) res_3_2/conv_0/w (64, 64, 3, 3, 3) block.3.dilation_block.2.residual_block.0.convolutional_block.3.weight (3, 3, 3, 64, 64) res_3_2/conv_1/w (64, 64, 3, 3, 3) block.3.dilation_block.2.residual_block.1.convolutional_block.3.weight State dictionary saved to /tmp/miccai_niftynet_pytorch/state_dict_pt.pth
If PyTorch is happy when loading our state dict into the model, we should be on the right track 🤞...
model.load_state_dict(state_dict_pt)
IncompatibleKeys(missing_keys=[], unexpected_keys=[])
No incompatible keys. Yay! 🎉
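One reminder before running inference with the converted model: it should be switched to evaluation mode so that batch normalization uses the running statistics we just loaded and the dropout layer is disabled. This is a general PyTorch requirement rather than something specific to this conversion:

```python
model.eval()  # use running batch norm statistics and disable dropout at inference time
```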
Something great about PyTorch is that the model parameters are easily accessible. Let's plot some of them before and after training:
model_initial = HighRes3DNet(num_input_modalities, num_classes, add_dropout_layer=True)
model_pretrained = model
By default, convolutional layers in PyTorch are initialized using He uniform variance scaling. These are the probability density functions (PDFs) of the kernel parameters of each convolutional layer. Note how the domain of each function changes with the corresponding input size at that layer.
visualization.plot_all_parameters(model_initial)
This is what the PDFs look like after training:
visualization.plot_all_parameters(model_pretrained)
The last step is to test the PyTorch model. We will preprocess the image according to the configuration file, initialize the reader, sampler and aggregator, run the inference, and verify that the results are consistent between NiftyNet and PyTorch. We need to match the configuration used during training in order to obtain consistent results. These are the relevant contents of the downloaded configuration file:
[Modality0]
path_to_search = data/OASIS/
filename_contains = nii
pixdim = (1.0, 1.0, 1.0)
axcodes = (R, A, S)
[NETWORK]
name = highres3dnet
volume_padding_size = 10
whitening = True
normalisation = True
normalise_foreground_only=True
foreground_type = mean_plus
histogram_ref_file = databrain_std_hist_models_otsu.txt
cutoff = (0.001, 0.999)
[INFERENCE]
border = 2
spatial_window_size = (128, 128, 128)
Let's parse the configuration file so that these settings can be used later in the pipeline:
config = ConfigParser()
config.read(config_path);
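Once parsed, individual settings can be retrieved by section and key. For example, using keys from the file shown above:

```python
volume_padding_size = int(config['NETWORK']['volume_padding_size'])
window_size = config['INFERENCE']['spatial_window_size']
print(volume_padding_size, window_size)  # 10 (128, 128, 128)
```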
The necessary preprocessing is described in the paper, code and configuration file.
NiftyNet offers some powerful I/O tools. We will use its readers, samplers and aggregators to read, preprocess and write all the files. There are multiple demos in the NiftyNet repository that show the usage of these modules.
%%capture
input_dict = dict(
path_to_search=str(data_dir),
filename_contains='nii',
axcodes=('R', 'A', 'S'),
pixdim=(1, 1, 1),
)
data_parameters = {
'image': input_dict,
}
reader = ImageReader().initialise(data_parameters)
_, image_data_dict, _ = reader()
original_image = image_data_dict['image']
original_image.shape
(160, 256, 256, 1, 1)
Looking at the shape of our image and knowing that the reader reoriented it into RAS+ orientation, we can see that it represents $160$ sagittal slices of $256 \times 256$ pixels, with $1$ channel (monomodal) and $1$ time point. Let's see what it looks like:
plot_volume(original_image, title='Original volume')