NiftyNet is "an open source convolutional neural networks platform for medical image analysis and image-guided therapy" built on top of TensorFlow. Due to its available implementations of successful architectures, patch-based sampling and straightforward configuration, it has become a popular choice to get started with deep learning in medical imaging.

PyTorch is "an open source deep learning platform that provides a seamless path from research prototyping to production deployment". It is low-level enough to offer a lot of control over what is going on under the hood during training, and its dynamic computational graph allows for easy debugging. Being a generic deep learning framework, it is not tailored to the needs of the medical imaging field, although its popularity in this field is increasing rapidly.

One can extend a NiftyNet application, but it is not straightforward without being familiar with the framework and fluent in TensorFlow 1.X. Therefore it can be convenient to implement applications in PyTorch using NiftyNet models and functionalities. In particular, combining both frameworks allows for fast architecture experimentation and transfer learning.

So why not use both? In this tutorial we will port the parameters of a model trained on NiftyNet to a PyTorch model and compare the results of running an inference using both frameworks.

Although NiftyNet supports different applications, it is mostly used for medical image segmentation.

***Image segmentation using deep learning*** were the 5 most common words in all full paper titles from *both* MICCAI 2018 and MIDL 2019, which shows the interest of the medical imaging community in the topic.

"Image segmentation using deep learning", guess this is the hottest topic in MIDL #MIDL2019 @midl_conference pic.twitter.com/64smdMjnxY

— Hua Ma (@forever_pippo) July 8, 2019

Just like #miccai2018! pic.twitter.com/3ZTHxj9iPT

— Julia Schnabel (@ja_schnabel) July 8, 2019

HighRes3DNet is the architecture proposed by Li et al. in *On the Compactness, Efficiency, and Representation of 3D Convolutional Networks: Brain Parcellation as a Pretext Task*.

The authors used NiftyNet to implement and train a model based on this architecture to perform brain parcellation using $T_1$-weighted MR images from the ADNI dataset. They achieved competitive segmentation performance compared with state-of-the-art architectures such as DeepMedic or U-Net.

This figure from the paper shows a parcellation produced by HighRes3DNet:

The code of the architecture is in the NiftyNet GitHub repository. The authors have uploaded the parameters and configuration file to the Model Zoo.

After reading the paper and the code, it is relatively straightforward to implement the same architecture using PyTorch.

All the code is hosted in a GitHub repository: `fepegar/miccai-educational-challenge-2019`.

The latest release can also be found on the Zenodo repository under this DOI: 10.5281/zenodo.3352316.

If you have a Google account, the best way to run this notebook seamlessly is using Google Colab. You will need to click on "Open in playground", at the top left:

You will also get a couple of warnings that you can safely ignore. The first one warns about this notebook not being authored by Google and the second one asks for confirmation to reset all runtimes. These are valid points, but will not affect this tutorial.

Please report any issues on GitHub and I will fix them. You can also drop me an email if you have any questions or comments.

To write this notebook I used Ubuntu 18.04 installed on an Alienware 13 R3 laptop, which includes a 6-GB GeForce GTX 1060 NVIDIA GPU. I am using CUDA 9.0.

Inference using PyTorch took 5725 MB of GPU memory. TensorFlow usually takes as much as possible beforehand.

To run this notebook locally, I recommend downloading the repository and creating a `conda` environment:

```
git clone https://github.com/fepegar/miccai-educational-challenge-2019.git
cd miccai-educational-challenge-2019
conda create -n mec2019 python=3.6 -y
conda activate mec2019 && conda install jupyterlab -y && jupyter lab
```

To use interactive plots, set the following variable to `True`. You might need to run the volume visualization cells individually after running the whole notebook. This feature is experimental and therefore disabled by default.

In [0]:

```
interactive_plots = False
```

In [0]:

```
%%capture --no-stderr
# This might take about 30 seconds
!rm -rf NiftyNet && git clone https://github.com/NifTK/NiftyNet --depth 1
!cd NiftyNet && git checkout df0f86733357fdc92bbc191c8fec0dcf49aa5499 && cd ..
!pip install -r NiftyNet/requirements-gpu.txt
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/requirements.txt
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/tf2pt.py
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/utils.py
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/visualization.py
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/highresnet_mapping.py
!curl -O https://raw.githubusercontent.com/fepegar/highresnet/master/GIFNiftyNet.ctbl
!pip install -r requirements.txt
!pip install --upgrade numpy
!pip install ipywidgets
import sys
sys.path.insert(0, 'NiftyNet')
```

In [3]:

```
%matplotlib inline
%config InlineBackend.figure_format='retina'
import os
import tempfile
from pathlib import Path
from configparser import ConfigParser
import numpy as np
import pandas as pd
try:
    # Fancy rendering of Pandas tables
    import google.colab.data_table
    %load_ext google.colab.data_table
    print("We are on Google Colab")
except ModuleNotFoundError:
    print("We are not on Google Colab")
    pd.set_option('display.max_colwidth', -1)  # do not truncate strings when displaying data frames
    pd.set_option('display.max_rows', None)  # show all rows
import torch
from highresnet import HighRes3DNet
```

We are on Google Colab

In [0]:

```
%%capture
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.python.util import deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False
import tf2pt
import utils
import visualization
import highresnet_mapping
if interactive_plots:  # for Colab or Jupyter
    plot_volume = visualization.plot_volume_interactive
else:  # for HTML, GitHub or nbviewer
    plot_volume = visualization.plot_volume
from niftynet.io.image_reader import ImageReader
from niftynet.engine.sampler_grid_v2 import GridSampler
from niftynet.engine.windows_aggregator_grid import GridSamplesAggregator
from niftynet.layer.pad import PadLayer
from niftynet.layer.binary_masking import BinaryMaskingLayer
from niftynet.layer.histogram_normalisation import HistogramNormalisationLayer
from niftynet.layer.mean_variance_normalisation import MeanVarNormalisationLayer
```

We can use NiftyNet's `net_download` to get everything we need from the Model Zoo entry corresponding to brain parcellation using HighRes3DNet:

In [0]:

```
%%capture
%run NiftyNet/net_download.py highres3dnet_brain_parcellation_model_zoo
```

In [6]:

```
niftynet_dir = Path('~/niftynet').expanduser()
utils.list_files(niftynet_dir)
```

There are three directories under `~/niftynet`:

- `extensions` is a Python package that contains the [configuration file](https://niftynet.readthedocs.io/en/dev/config_spec.html).
- `models` contains the landmarks for histogram standardization (an MRI preprocessing step) and the network parameters.
- `data` contains an OASIS MRI sample that can be used to test the model.

In the variable names below, `nn` stands for NiftyNet, `tf` for TensorFlow and `pt` for PyTorch.

In [0]:

```
models_dir = niftynet_dir / 'models'
zoo_entry = 'highres3dnet_brain_parcellation'
input_checkpoint_tf_name = 'model.ckpt-33000'
input_checkpoint_tf_path = models_dir / zoo_entry / 'models' / input_checkpoint_tf_name
data_dir = niftynet_dir / 'data' / 'OASIS'
config_path = niftynet_dir / 'extensions' / zoo_entry / 'highres3dnet_config_eval.ini'
histogram_landmarks_path = models_dir / zoo_entry / 'databrain_std_hist_models_otsu.txt'
tempdir = Path(tempfile.gettempdir()) / 'miccai_niftynet_pytorch'
tempdir.mkdir(exist_ok=True)
output_csv_tf_path = tempdir / 'variables_tf.csv'
output_state_dict_tf_path = tempdir / 'state_dict_tf.pth'
output_state_dict_pt_path = tempdir / 'state_dict_pt.pth'
prediction_pt_dir = tempdir / 'prediction'
prediction_pt_dir.mkdir(exist_ok=True)
color_table_path = 'GIFNiftyNet.ctbl'
```

There are two modules in the repository that are relevant for this section:

- `tf2pt` contains generic functions that can be used to transform any TensorFlow model to PyTorch.
- `highresnet_mapping` contains custom functions that are specific to HighRes3DNet.

Let's see what variables are stored in the checkpoint.

These variables are filtered out by `highresnet_mapping.is_not_valid()` for clarity:

- Variables used by the Adam optimizer during training
- Variables with no shape. They won't help much.
- Variables containing `biased` or `ExponentialMovingAverage`. I have experimented with them and the results using these variables turned out to be different from the ones generated by NiftyNet.
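The filtering rules listed above could be expressed as a small predicate. This is only a sketch under assumptions: the name `should_skip`, its signature, the `'Adam'` substring test and the example variable names are hypothetical and are not taken from `highresnet_mapping.is_not_valid()`.

```python
def should_skip(name, shape):
    """Return True if a checkpoint variable should be ignored (hypothetical helper)."""
    # Assumption: optimizer slot variables carry 'Adam' in their names
    if 'Adam' in name:
        return True
    # Variables with no shape are scalars, not layer parameters
    if len(shape) == 0:
        return True
    # Moving-average bookkeeping variables
    if 'biased' in name or 'ExponentialMovingAverage' in name:
        return True
    return False


# Illustrative (made-up) checkpoint entries: (name, shape)
variables = [
    ('conv_0_bn_relu/conv_/w', (3, 3, 3, 1, 16)),
    ('conv_0_bn_relu/conv_/w/Adam', (3, 3, 3, 1, 16)),
    ('global_step', ()),
    ('conv_0_bn_relu/bn_/moving_mean/biased', (16,)),
]
kept = [name for name, shape in variables if not should_skip(name, shape)]
print(kept)
```

Only the convolution kernel survives the filter in this toy list.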

I wrote `tf2pt.checkpoint_tf_to_state_dict_tf()` after reading the corresponding TensorFlow docs and Stack Overflow answers.

In [8]:

```
# %%capture
tf2pt.checkpoint_tf_to_state_dict_tf(
    input_checkpoint_tf_path=input_checkpoint_tf_path,
    output_csv_tf_path=output_csv_tf_path,
    output_state_dict_tf_path=output_state_dict_tf_path,
    filter_out_function=highresnet_mapping.is_not_valid,
    replace_string='HighRes3DNet/',
)
data_frame_tf = pd.read_csv(output_csv_tf_path)
state_dict_tf = torch.load(output_state_dict_tf_path)
```

In [9]:

```
data_frame_tf
```

Out[9]:

| | Unnamed: 0 | name | shape |
|---|---|---|---|
| 0 | 0 | conv_0_bn_relu/bn_/beta | 16 |
| 1 | 1 | conv_0_bn_relu/bn_/gamma | 16 |
| 2 | 2 | conv_0_bn_relu/bn_/moving_mean | 16 |
| 3 | 3 | conv_0_bn_relu/bn_/moving_variance | 16 |
| 4 | 4 | conv_0_bn_relu/conv_/w | 3, 3, 3, 1, 16 |
| 5 | 5 | conv_1_bn_relu/bn_/beta | 80 |
| 6 | 6 | conv_1_bn_relu/bn_/gamma | 80 |
| 7 | 7 | conv_1_bn_relu/bn_/moving_mean | 80 |
| 8 | 8 | conv_1_bn_relu/bn_/moving_variance | 80 |
| 9 | 9 | conv_1_bn_relu/conv_/w | 1, 1, 1, 64, 80 |
| 10 | 10 | conv_2_bn/bn_/beta | 160 |
| 11 | 11 | conv_2_bn/bn_/gamma | 160 |
| 12 | 12 | conv_2_bn/bn_/moving_mean | 160 |
| 13 | 13 | conv_2_bn/bn_/moving_variance | 160 |
| 14 | 14 | conv_2_bn/conv_/w | 1, 1, 1, 80, 160 |
| 15 | 15 | res_1_0/bn_0/beta | 16 |
| 16 | 16 | res_1_0/bn_0/gamma | 16 |
| 17 | 17 | res_1_0/bn_0/moving_mean | 16 |
| 18 | 18 | res_1_0/bn_0/moving_variance | 16 |
| 19 | 19 | res_1_0/bn_1/beta | 16 |
| 20 | 20 | res_1_0/bn_1/gamma | 16 |
| 21 | 21 | res_1_0/bn_1/moving_mean | 16 |
| 22 | 22 | res_1_0/bn_1/moving_variance | 16 |
| 23 | 23 | res_1_0/conv_0/w | 3, 3, 3, 16, 16 |
| 24 | 24 | res_1_0/conv_1/w | 3, 3, 3, 16, 16 |
| 25 | 25 | res_1_1/bn_0/beta | 16 |
| 26 | 26 | res_1_1/bn_0/gamma | 16 |
| 27 | 27 | res_1_1/bn_0/moving_mean | 16 |
| 28 | 28 | res_1_1/bn_0/moving_variance | 16 |
| 29 | 29 | res_1_1/bn_1/beta | 16 |
| ... | ... | ... | ... |
| 75 | 75 | res_3_0/bn_0/beta | 32 |
| 76 | 76 | res_3_0/bn_0/gamma | 32 |
| 77 | 77 | res_3_0/bn_0/moving_mean | 32 |
| 78 | 78 | res_3_0/bn_0/moving_variance | 32 |
| 79 | 79 | res_3_0/bn_1/beta | 64 |
| 80 | 80 | res_3_0/bn_1/gamma | 64 |
| 81 | 81 | res_3_0/bn_1/moving_mean | 64 |
| 82 | 82 | res_3_0/bn_1/moving_variance | 64 |
| 83 | 83 | res_3_0/conv_0/w | 3, 3, 3, 32, 64 |
| 84 | 84 | res_3_0/conv_1/w | 3, 3, 3, 64, 64 |
| 85 | 85 | res_3_1/bn_0/beta | 64 |
| 86 | 86 | res_3_1/bn_0/gamma | 64 |
| 87 | 87 | res_3_1/bn_0/moving_mean | 64 |
| 88 | 88 | res_3_1/bn_0/moving_variance | 64 |
| 89 | 89 | res_3_1/bn_1/beta | 64 |
| 90 | 90 | res_3_1/bn_1/gamma | 64 |
| 91 | 91 | res_3_1/bn_1/moving_mean | 64 |
| 92 | 92 | res_3_1/bn_1/moving_variance | 64 |
| 93 | 93 | res_3_1/conv_0/w | 3, 3, 3, 64, 64 |
| 94 | 94 | res_3_1/conv_1/w | 3, 3, 3, 64, 64 |
| 95 | 95 | res_3_2/bn_0/beta | 64 |
| 96 | 96 | res_3_2/bn_0/gamma | 64 |
| 97 | 97 | res_3_2/bn_0/moving_mean | 64 |
| 98 | 98 | res_3_2/bn_0/moving_variance | 64 |
| 99 | 99 | res_3_2/bn_1/beta | 64 |
| 100 | 100 | res_3_2/bn_1/gamma | 64 |
| 101 | 101 | res_3_2/bn_1/moving_mean | 64 |
| 102 | 102 | res_3_2/bn_1/moving_variance | 64 |
| 103 | 103 | res_3_2/conv_0/w | 3, 3, 3, 64, 64 |
| 104 | 104 | res_3_2/conv_1/w | 3, 3, 3, 64, 64 |

105 rows × 3 columns

The convolution kernels, named `conv_/w`, are stored with a shape representing the three spatial dimensions, the input channels and the output channels: $(Depth, Height, Width, Channels_{in}, Channels_{out})$. Calling the spatial dimensions *depth*, *height* and *width* does not make a lot of sense when dealing with 3D medical images, but we will keep this computer vision terminology as it is consistent with the documentation of both PyTorch and TensorFlow.
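PyTorch stores the weights of `torch.nn.Conv3d` as $(Channels_{out}, Channels_{in}, Depth, Height, Width)$, so the TensorFlow kernels need their axes reordered. The sketch below shows the permutation on shapes only; the names `TF_TO_PT_PERMUTATION` and `conv3d_shape_tf_to_pt` are hypothetical and not part of `tf2pt`.

```python
# TensorFlow axis order: (depth, height, width, channels_in, channels_out)
# PyTorch axis order:    (channels_out, channels_in, depth, height, width)
TF_TO_PT_PERMUTATION = (4, 3, 0, 1, 2)


def conv3d_shape_tf_to_pt(shape_tf):
    """Reorder a TensorFlow 3D kernel shape into the PyTorch convention."""
    return tuple(shape_tf[axis] for axis in TF_TO_PT_PERMUTATION)


shape_pt = conv3d_shape_tf_to_pt((3, 3, 3, 1, 16))
print(shape_pt)  # (16, 1, 3, 3, 3)
```

On an actual tensor, the same reordering can be applied with `tensor.permute(4, 3, 0, 1, 2)`.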

There are three blocks with increasing kernel dilation composed of three residual blocks each, which have two convolutional layers inside. That's $3 \times 3 \times 2 = 18$ layers. The other three convolutional layers are the initial convolution before the first residual block, a convolution before dropout and a convolution to expand the channels to the number of output classes.

Apparently, *all* the convolutional layers have an associated batch normalization layer, which differs from the figure in the paper. That makes 21 convolutional layers and 21 batch normalization layers whose parameters must be transferred.
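As a quick check of this count, the arithmetic can be written out. It assumes, as the checkpoint listing suggests, one kernel `w` per convolution (no bias) and four tensors per batch normalization layer; the variable names are just illustrative.

```python
num_dilation_blocks = 3
residual_blocks_per_dilation_block = 3
convs_per_residual_block = 2
other_convs = 3  # initial convolution, convolution before dropout, classifier

num_convs = (
    num_dilation_blocks
    * residual_blocks_per_dilation_block
    * convs_per_residual_block
    + other_convs
)
num_batch_norms = num_convs  # one batch normalization layer per convolution
tensors_per_batch_norm = 4  # beta, gamma, moving_mean, moving_variance
num_tensors = num_convs + num_batch_norms * tensors_per_batch_norm
print(num_convs, num_tensors)  # 21 105
```

The total of 105 tensors agrees with the 105 rows of the checkpoint listing.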

Each batch normalization layer stores four sets of parameters: the *moving mean* $\mathrm{E}[x]$, the *moving variance* $\mathrm{Var}[x]$ and the affine transformation parameters $\gamma$ (scale or *weight*) and $\beta$ (shift or *bias*):

\begin{align}
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
\end{align}
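To make the formula concrete, here is a scalar sketch of batch normalization at inference time. The function name is hypothetical, and the default `epsilon` mirrors the default of PyTorch's batch normalization layers, which is an assumption about what the trained model used.

```python
import math


def batch_norm_inference(x, mean, variance, gamma, beta, epsilon=1e-5):
    """Normalize x with stored statistics, then apply the affine transform."""
    return (x - mean) / math.sqrt(variance + epsilon) * gamma + beta


# (6 - 2) / sqrt(4) * 3 + 1 = 7
y = batch_norm_inference(6.0, mean=2.0, variance=4.0, gamma=3.0, beta=1.0, epsilon=0.0)
print(y)  # 7.0
```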

In [0]:

```
num_input_modalities = 1
num_classes = 160
model = HighRes3DNet(num_input_modalities, num_classes, add_dropout_layer=True)
```

Let's see what the variable names created by PyTorch are:

In [11]:

```
state_dict_pt = model.state_dict()
rows = []
for name, parameters in state_dict_pt.items():
    if 'num_batches_tracked' in name:  # filter out for clarity
        continue
    shape = ', '.join(str(n) for n in parameters.shape)
    row = {'name': name, 'shape': shape}
    rows.append(row)
df_pt = pd.DataFrame.from_dict(rows)
df_pt
```

Out[11]:

| | name | shape |
|---|---|---|
| 0 | block.0.convolutional_block.1.weight | 16, 1, 3, 3, 3 |
| 1 | block.0.convolutional_block.2.weight | 16 |
| 2 | block.0.convolutional_block.2.bias | 16 |
| 3 | block.0.convolutional_block.2.running_mean | 16 |
| 4 | block.0.convolutional_block.2.running_var | 16 |
| 5 | block.1.dilation_block.0.residual_block.0.conv... | 16 |
| 6 | block.1.dilation_block.0.residual_block.0.conv... | 16 |
| 7 | block.1.dilation_block.0.residual_block.0.conv... | 16 |
| 8 | block.1.dilation_block.0.residual_block.0.conv... | 16 |
| 9 | block.1.dilation_block.0.residual_block.0.conv... | 16, 16, 3, 3, 3 |
| 10 | block.1.dilation_block.0.residual_block.1.conv... | 16 |
| 11 | block.1.dilation_block.0.residual_block.1.conv... | 16 |
| 12 | block.1.dilation_block.0.residual_block.1.conv... | 16 |
| 13 | block.1.dilation_block.0.residual_block.1.conv... | 16 |
| 14 | block.1.dilation_block.0.residual_block.1.conv... | 16, 16, 3, 3, 3 |
| 15 | block.1.dilation_block.1.residual_block.0.conv... | 16 |
| 16 | block.1.dilation_block.1.residual_block.0.conv... | 16 |
| 17 | block.1.dilation_block.1.residual_block.0.conv... | 16 |
| 18 | block.1.dilation_block.1.residual_block.0.conv... | 16 |
| 19 | block.1.dilation_block.1.residual_block.0.conv... | 16, 16, 3, 3, 3 |
| 20 | block.1.dilation_block.1.residual_block.1.conv... | 16 |
| 21 | block.1.dilation_block.1.residual_block.1.conv... | 16 |
| 22 | block.1.dilation_block.1.residual_block.1.conv... | 16 |
| 23 | block.1.dilation_block.1.residual_block.1.conv... | 16 |
| 24 | block.1.dilation_block.1.residual_block.1.conv... | 16, 16, 3, 3, 3 |
| 25 | block.1.dilation_block.2.residual_block.0.conv... | 16 |
| 26 | block.1.dilation_block.2.residual_block.0.conv... | 16 |
| 27 | block.1.dilation_block.2.residual_block.0.conv... | 16 |
| 28 | block.1.dilation_block.2.residual_block.0.conv... | 16 |
| 29 | block.1.dilation_block.2.residual_block.0.conv... | 16, 16, 3, 3, 3 |
| ... | ... | ... |
| 75 | block.3.dilation_block.1.residual_block.0.conv... | 64 |
| 76 | block.3.dilation_block.1.residual_block.0.conv... | 64 |
| 77 | block.3.dilation_block.1.residual_block.0.conv... | 64 |
| 78 | block.3.dilation_block.1.residual_block.0.conv... | 64 |
| 79 | block.3.dilation_block.1.residual_block.0.conv... | 64, 64, 3, 3, 3 |
| 80 | block.3.dilation_block.1.residual_block.1.conv... | 64 |
| 81 | block.3.dilation_block.1.residual_block.1.conv... | 64 |
| 82 | block.3.dilation_block.1.residual_block.1.conv... | 64 |
| 83 | block.3.dilation_block.1.residual_block.1.conv... | 64 |
| 84 | block.3.dilation_block.1.residual_block.1.conv... | 64, 64, 3, 3, 3 |
| 85 | block.3.dilation_block.2.residual_block.0.conv... | 64 |
| 86 | block.3.dilation_block.2.residual_block.0.conv... | 64 |
| 87 | block.3.dilation_block.2.residual_block.0.conv... | 64 |
| 88 | block.3.dilation_block.2.residual_block.0.conv... | 64 |
| 89 | block.3.dilation_block.2.residual_block.0.conv... | 64, 64, 3, 3, 3 |
| 90 | block.3.dilation_block.2.residual_block.1.conv... | 64 |
| 91 | block.3.dilation_block.2.residual_block.1.conv... | 64 |
| 92 | block.3.dilation_block.2.residual_block.1.conv... | 64 |
| 93 | block.3.dilation_block.2.residual_block.1.conv... | 64 |
| 94 | block.3.dilation_block.2.residual_block.1.conv... | 64, 64, 3, 3, 3 |
| 95 | block.4.convolutional_block.0.weight | 80, 64, 1, 1, 1 |
| 96 | block.4.convolutional_block.1.weight | 80 |
| 97 | block.4.convolutional_block.1.bias | 80 |
| 98 | block.4.convolutional_block.1.running_mean | 80 |
| 99 | block.4.convolutional_block.1.running_var | 80 |
| 100 | block.6.convolutional_block.0.weight | 160, 80, 1, 1, 1 |
| 101 | block.6.convolutional_block.1.weight | 160 |
| 102 | block.6.convolutional_block.1.bias | 160 |
| 103 | block.6.convolutional_block.1.running_mean | 160 |
| 104 | block.6.convolutional_block.1.running_var | 160 |

105 rows × 2 columns

`moving_mean` and `moving_variance` are called `running_mean` and `running_var` in PyTorch. Also, $\gamma$ and $\beta$ are called `weight` and `bias`.
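This renaming of batch normalization tensors can be captured in a small lookup table. The sketch below is hypothetical and only handles the last component of a TensorFlow name; translating the prefix into a PyTorch module path is the job of `highresnet_mapping.tf2pt_name` and is not reproduced here.

```python
# TensorFlow batch normalization names and their PyTorch equivalents
BATCH_NORM_NAMES_TF_TO_PT = {
    'moving_mean': 'running_mean',
    'moving_variance': 'running_var',
    'gamma': 'weight',
    'beta': 'bias',
}


def rename_batch_norm_suffix(name_tf):
    """Split off the last name component and translate it (hypothetical helper)."""
    prefix, suffix = name_tf.rsplit('/', 1)
    return prefix, BATCH_NORM_NAMES_TF_TO_PT[suffix]


print(rename_batch_norm_suffix('res_1_0/bn_0/moving_variance'))
```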

`tf2pt.tf2pt()` receives a TensorFlow-like variable and returns the corresponding PyTorch-like variable.

In [12]:

```
for name_tf, tensor_tf in state_dict_tf.items():
    shape_tf = tuple(tensor_tf.shape)
    print(f'{str(shape_tf):18}', name_tf)

    # Convert TensorFlow name to PyTorch name
    mapping_function = highresnet_mapping.tf2pt_name
    name_pt, tensor_pt = tf2pt.tf2pt(name_tf, tensor_tf, mapping_function)

    shape_pt = tuple(state_dict_pt[name_pt].shape)
    print(f'{str(shape_pt):18}', name_pt)

    # Sanity check
    if sum(shape_tf) != sum(shape_pt):
        raise ValueError

    # Add weights to PyTorch state dictionary
    state_dict_pt[name_pt] = tensor_pt
    print()

torch.save(state_dict_pt, output_state_dict_pt_path)
print('State dictionary saved to', output_state_dict_pt_path)
```
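The sanity check in the loop compares the sums of the dimensions, which is cheap but can accept two shapes that are not permutations of each other. A slightly stricter variant, sketched here as a suggestion rather than as part of `tf2pt`, compares the sorted dimensions instead. The function name is hypothetical.

```python
def same_dims_up_to_permutation(shape_tf, shape_pt):
    """True if both shapes contain the same dimensions in any order."""
    return sorted(shape_tf) == sorted(shape_pt)


# The sum-based check would accept this mismatch: 2 + 3 == 1 + 4
print(sum((2, 3)) == sum((1, 4)))  # True
print(same_dims_up_to_permutation((2, 3), (1, 4)))  # False
# A genuine kernel permutation still passes
print(same_dims_up_to_permutation((3, 3, 3, 1, 16), (16, 1, 3, 3, 3)))  # True
```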

If PyTorch is happy when loading our state dict into the model, we should be on the right track 🤞...

In [13]:

```
model.load_state_dict(state_dict_pt)
```

Out[13]:

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

No incompatible keys. Yay! 🎉

In [0]:

```
model_initial = HighRes3DNet(num_input_modalities, num_classes, add_dropout_layer=True)
model_pretrained = model
```

In [15]:

```
visualization.plot_all_parameters(model_initial)
```

This is what the probability density functions (PDFs) of the parameters look like after training:

In [16]:

```
visualization.plot_all_parameters(model_pretrained)
```

```
[Modality0]
path_to_search = data/OASIS/
filename_contains = nii
pixdim = (1.0, 1.0, 1.0)
axcodes = (R, A, S)
[NETWORK]
name = highres3dnet
volume_padding_size = 10
whitening = True
normalisation = True
normalise_foreground_only=True
foreground_type = mean_plus
histogram_ref_file = databrain_std_hist_models_otsu.txt
cutoff = (0.001, 0.999)
[INFERENCE]
border = 2
spatial_window_size = (128, 128, 128)
```

In [0]:

```
config = ConfigParser()
config.read(config_path);
```
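Values read with `ConfigParser` are plain strings, so they need converting before use. The following is a minimal sketch, assuming only the standard `configparser` and `ast` modules and an inline excerpt of the settings shown above. It is not NiftyNet's own configuration parser.

```python
from ast import literal_eval
from configparser import ConfigParser

# Inline excerpt of the configuration, used here instead of a file on disk
config_text = """
[NETWORK]
volume_padding_size = 10
whitening = True

[INFERENCE]
border = 2
spatial_window_size = (128, 128, 128)
"""

example_config = ConfigParser()
example_config.read_string(config_text)

# Tuples are stored as text and can be parsed as Python literals
window_size = literal_eval(example_config['INFERENCE']['spatial_window_size'])
border = example_config['INFERENCE'].getint('border')
whitening = example_config['NETWORK'].getboolean('whitening')
print(window_size, border, whitening)
```

Note that entries such as `axcodes = (R, A, S)` are not valid Python literals and would need their own parsing.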

The necessary preprocessing is described in the paper, code and configuration file.

NiftyNet offers some powerful I/O tools. We will use its readers, samplers and aggregators to read, preprocess and write all the files. There are multiple demos in the NiftyNet repository that show the usage of these modules.

In [0]:

```
%%capture
input_dict = dict(
    path_to_search=str(data_dir),
    filename_contains='nii',
    axcodes=('R', 'A', 'S'),
    pixdim=(1, 1, 1),
)
data_parameters = {
    'image': input_dict,
}
reader = ImageReader().initialise(data_parameters)
```

In [19]:

```
_, image_data_dict, _ = reader()
original_image = image_data_dict['image']
original_image.shape
```

Out[19]:

(160, 256, 256, 1, 1)

In [20]:

```
plot_volume(original_image, title='Original volume')
```