#!/usr/bin/env python # coding: utf-8 # # Combining the power of PyTorch and NiftyNet # ## Contents # 1. [Introduction](#Introduction) # 2. [Running the notebook](#Running-the-notebook) # 3. [Setup](#Setup) # 4. [Transferring parameters from NiftyNet to PyTorch](#Transferring-parameters-from-NiftyNet-to-PyTorch) # 5. [Testing the model](#Testing-the-model) # 6. [Future work](#Future-work) # 7. [Conclusion](#Conclusion) # 8. [Acknowledgments](#Acknowledgments) # # # ## Introduction # NiftyNet is "[an open source convolutional neural networks platform for medical image analysis and image-guided therapy](https://niftynet.io/)" built on top of [TensorFlow](https://www.tensorflow.org/). Due to its available implementations of successful architectures, patch-based sampling and straightforward configuration, it has become a [popular choice](https://github.com/NifTK/NiftyNet/network/members) to get started with deep learning in medical imaging. # # PyTorch is "[an open source deep learning platform that provides a seamless path from research prototyping to production deployment](https://pytorch.org/)". It is low-level enough to offer a lot of control over what is going on under the hood during training, and its [dynamic computational graph](https://medium.com/intuitionmachine/pytorch-dynamic-computational-graphs-and-modular-deep-learning-7e7f89f18d1) allows for easy debugging. Being a generic deep learning framework, it is not tailored to the needs of the medical imaging field, although its popularity in this field is increasing rapidly. # # One can [extend a NiftyNet application](https://niftynet.readthedocs.io/en/dev/extending_app.html), but it is not straightforward without being familiar with the framework and fluent in TensorFlow 1.X. Therefore it can be convenient to implement applications in PyTorch using NiftyNet models and functionalities. In particular, combining both frameworks allows for fast architecture experimentation and transfer learning. # # So why not use [both](https://www.youtube.com/watch?v=vqgSO8_cRio&feature=youtu.be&t=5)? In this tutorial we will port the parameters of a model trained on NiftyNet to a PyTorch model and compare the results of running an inference using both frameworks. # ### Image segmentation # Although NiftyNet supports different applications, it is mostly used for medical image segmentation. # # ***Image segmentation using deep learning*** were the 5 most common words in all full paper titles from *both* [MICCAI 2018](https://www.miccai2018.org/en/) and [MIDL 2019](https://2019.midl.io/), which shows the interest of the medical imaging community in the topic. #
# > "Image segmentation using deep learning", guess this is the hottest topic in MIDL #MIDL2019 (Hua Ma, @forever_pippo, July 8, 2019)
# >
# > Just like #miccai2018! (Julia Schnabel, @ja_schnabel, July 8, 2019)
# ### HighRes3DNet # HighRes3DNet is a residual convolutional neural network architecture designed to have a large receptive field and preserve a high resolution using a relatively small number of parameters. It was presented in 2017 by Li et al. at IPMI: [*On the Compactness, Efficiency, and Representation of 3D Convolutional Networks: Brain Parcellation as a Pretext Task*](https://arxiv.org/abs/1707.01992). # The authors used NiftyNet to implement and train a model based on this architecture to perform [brain parcellation](https://ieeexplore.ieee.org/document/7086081?arnumber=7086081) using $T_1$-weighted MR images from the [ADNI dataset](http://adni.loni.usc.edu/). They achieved competitive segmentation performance compared with state-of-the-art architectures such as [DeepMedic](https://biomedia.doc.ic.ac.uk/software/deepmedic/) or [U-Net](https://arxiv.org/abs/1606.06650). # # A figure in the paper shows a parcellation produced by HighRes3DNet: the input MRI next to the output parcellation. # # The code for the architecture is in the [NiftyNet GitHub repository](https://github.com/NifTK/NiftyNet/blob/dev/niftynet/network/highres3dnet.py). The authors have uploaded the parameters and configuration file to the [Model Zoo](https://github.com/NifTK/NiftyNetModelZoo/tree/5-reorganising-with-lfs/highres3dnet_brain_parcellation). # # After reading the paper and the code, it is relatively straightforward to [implement the same architecture using PyTorch](https://github.com/fepegar/highresnet). # ## Running the notebook # All the code is hosted in a GitHub repository: # [`fepegar/miccai-educational-challenge-2019`](https://github.com/fepegar/miccai-educational-challenge-2019). # # The latest release can also be found on the Zenodo repository under this DOI: [10.5281/zenodo.3352316](https://doi.org/10.5281/zenodo.3352316). # ### Online # If you have a Google account, the best way to run this notebook seamlessly is using [Google Colab](https://colab.research.google.com/drive/1vqDojKuC4Svb97LdoEyZQygm3jccX4hr). You will need to click on "Open in playground", at the top left: # # ![Playground mode screenshot](https://github.com/fepegar/miccai-educational-challenge-2019/raw/master/images/playground.png) # # You will also get a couple of warnings that you can safely ignore. The first one warns about this notebook not being authored by Google and the second one asks for confirmation to reset all runtimes. These are valid points, but they will not affect this tutorial. # # [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1vqDojKuC4Svb97LdoEyZQygm3jccX4hr) # # --- # # Please [report any issues on GitHub](https://github.com/fepegar/miccai-educational-challenge-2019/issues/new) and I will fix them. You can also [drop me an email](mailto:fernando.perez.garcia.17@ucl.ac.uk?subject=Combining%20the%20power%20of%20PyTorch%20and%20NiftyNet) if you have any questions or comments. # # # ### Locally # To write this notebook I used Ubuntu 18.04 installed on an Alienware 13 R3 laptop, which includes a 6-GB GeForce GTX 1060 NVIDIA GPU. I am using CUDA 9.0. # # Inference using PyTorch took 5725 MB of GPU memory; by default, TensorFlow allocates as much GPU memory as possible beforehand.
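# If you run the notebook on your own machine, you can check the peak GPU memory used by PyTorch with a couple of lines. This is only a minimal sketch, assuming a CUDA-capable device is available and that it is run after the inference cells further down:
# In[ ]:
import torch

if torch.cuda.is_available():
    # Peak memory allocated by tensors on the current CUDA device, in MB
    peak_mb = torch.cuda.max_memory_allocated() / 1024 ** 2
    print(f'Peak GPU memory allocated by PyTorch: {peak_mb:.0f} MB')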
# To run this notebook locally, I recommend downloading the repository and creating a [`conda`](https://docs.conda.io/en/latest/miniconda.html) environment: # # ```shell # git clone https://github.com/fepegar/miccai-educational-challenge-2019.git # cd miccai-educational-challenge-2019 # conda create -n mec2019 python=3.6 -y # conda activate mec2019 && conda install jupyterlab -y && jupyter lab # ``` # ### nbviewer # An already executed version of the notebook can be rendered using [nbviewer](https://nbviewer.jupyter.org/github/fepegar/miccai-educational-challenge-2019/blob/master/Combining_the_power_of_PyTorch_and_NiftyNet.ipynb?flush_cache=true). # # # # # ### Interactive volume plots # If you run the notebook, you can use interactive plots to navigate through the volume slices by setting this variable to `True`. You might need to run the volume visualization cells individually after running the whole notebook. This feature is experimental and therefore disabled by default. # In[ ]: interactive_plots = False # ## Setup # ### Install and import libraries # Clone NiftyNet and some custom Python libraries for this notebook. This might take one or two minutes. # In[ ]: get_ipython().run_cell_magic('capture', '--no-stderr', "# This might take about 30 seconds\n!rm -rf NiftyNet && git clone https://github.com/NifTK/NiftyNet --depth 1\n!cd NiftyNet && git checkout df0f86733357fdc92bbc191c8fec0dcf49aa5499 && cd ..\n!pip install -r NiftyNet/requirements-gpu.txt\n!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/requirements.txt\n!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/tf2pt.py\n!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/utils.py\n!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/visualization.py\n!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/highresnet_mapping.py\n!curl -O https://raw.githubusercontent.com/fepegar/highresnet/master/GIFNiftyNet.ctbl\n!pip install -r requirements.txt\n!pip install --upgrade numpy\n!pip install ipywidgets\nimport sys\nsys.path.insert(0, 'NiftyNet')\n") # In[3]: get_ipython().run_line_magic('matplotlib', 'inline') get_ipython().run_line_magic('config', "InlineBackend.figure_format='retina'") import os import tempfile from pathlib import Path from configparser import ConfigParser import numpy as np import pandas as pd try: # Fancy rendering of Pandas tables import google.colab.data_table get_ipython().run_line_magic('load_ext', 'google.colab.data_table') print("We are on Google Colab") except ModuleNotFoundError: print("We are not on Google Colab") pd.set_option('display.max_colwidth', -1) # do not truncate strings when displaying data frames pd.set_option('display.max_rows', None) # show all rows import torch from highresnet import HighRes3DNet # In[ ]: get_ipython().run_cell_magic('capture', '', "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\nfrom tensorflow.python.util import deprecation\ndeprecation._PRINT_DEPRECATION_WARNINGS = False\n\nimport tf2pt\nimport utils\nimport visualization\nimport highresnet_mapping\n\nif interactive_plots: # for Colab or Jupyter\n plot_volume = visualization.plot_volume_interactive\nelse: # for HTML, GitHub or nbviewer\n plot_volume = visualization.plot_volume\n\nfrom niftynet.io.image_reader import ImageReader\nfrom niftynet.engine.sampler_grid_v2 import GridSampler\nfrom niftynet.engine.windows_aggregator_grid 
import GridSamplesAggregator\nfrom niftynet.layer.pad import PadLayer\nfrom niftynet.layer.binary_masking import BinaryMaskingLayer\nfrom niftynet.layer.histogram_normalisation import HistogramNormalisationLayer\nfrom niftynet.layer.mean_variance_normalisation import MeanVarNormalisationLayer\n") # ### Download NiftyNet model and test data # We can use NiftyNet's `net_download` to get all we need from the [Model Zoo entry](https://github.com/NifTK/NiftyNetModelZoo/tree/5-reorganising-with-lfs/highres3dnet_brain_parcellation#downloading-model-zoo-files) corresponding to brain parcellation using HighRes3DNet: # In[ ]: get_ipython().run_cell_magic('capture', '', '%run NiftyNet/net_download.py highres3dnet_brain_parcellation_model_zoo\n') # In[6]: niftynet_dir = Path('~/niftynet').expanduser() utils.list_files(niftynet_dir) # There are three directories under `~/niftynet`: # 1. `extensions` is a Python package that contains the [configuration file](https://niftynet.readthedocs.io/en/dev/config_spec.html). # 2. `models` contains the landmarks for [histogram standardization](https://ieeexplore.ieee.org/document/836373) (an MRI preprocessing step) and the network parameters. # 3. `data` contains an [OASIS](https://www.oasis-brains.org/) MRI sample that can be used to test the model. # Here are the paths to the downloaded files and to the files that will be generated by the notebook. I use `nn` for NiftyNet, `tf` for TensorFlow and `pt` for PyTorch. # In[ ]: models_dir = niftynet_dir / 'models' zoo_entry = 'highres3dnet_brain_parcellation' input_checkpoint_tf_name = 'model.ckpt-33000' input_checkpoint_tf_path = models_dir / zoo_entry / 'models' / input_checkpoint_tf_name data_dir = niftynet_dir / 'data' / 'OASIS' config_path = niftynet_dir / 'extensions' / zoo_entry / 'highres3dnet_config_eval.ini' histogram_landmarks_path = models_dir / zoo_entry / 'databrain_std_hist_models_otsu.txt' tempdir = Path(tempfile.gettempdir()) / 'miccai_niftynet_pytorch' tempdir.mkdir(exist_ok=True) output_csv_tf_path = tempdir / 'variables_tf.csv' output_state_dict_tf_path = tempdir / 'state_dict_tf.pth' output_state_dict_pt_path = tempdir / 'state_dict_pt.pth' prediction_pt_dir = tempdir / 'prediction' prediction_pt_dir.mkdir(exist_ok=True) color_table_path = 'GIFNiftyNet.ctbl' # Note that the path to the checkpoint is not the path to an existing file, but the common prefix (basename) of the three checkpoint files. # ## Transferring parameters from NiftyNet to PyTorch # ### Variables in TensorFlow world # # There are two modules that are relevant for this section in the # [repository](https://github.com/fepegar/miccai-educational-challenge-2019). # [`tf2pt`](https://github.com/fepegar/miccai-educational-challenge-2019/blob/master/tf2pt.py) contains generic functions that can be used to transform any TensorFlow model to PyTorch. # [`highresnet_mapping`](https://github.com/fepegar/miccai-educational-challenge-2019/blob/master/highresnet_mapping.py) contains custom functions that are specific to HighRes3DNet. # # Let's see what variables are stored in the checkpoint. # # The following variables are filtered out by # [`highresnet_mapping.is_not_valid()`](https://github.com/fepegar/miccai-educational-challenge-2019/blob/c96777d654ac577c0dba218038f76c2497de946a/highresnet_mapping.py#L4-L11) # for clarity: # * Variables used by the Adam optimizer during training # * Variables with no shape, which are not useful here # * Variables containing `biased` or `ExponentialMovingAverage`.
I have experimented with these, and the results turned out to be different from the ones generated by NiftyNet. # We will store the variable names in a data frame to list them in this notebook, and the values in a Python dictionary to retrieve them later. I figured out the code in # [`tf2pt.checkpoint_tf_to_state_dict_tf()`](https://github.com/fepegar/miccai-educational-challenge-2019/blob/c96777d654ac577c0dba218038f76c2497de946a/tf2pt.py#L90-L129) # by reading the corresponding [TensorFlow docs](https://www.tensorflow.org/api_docs/python/tf/train/list_variables) and [Stack Overflow answers](https://stackoverflow.com/search?q=restore+tensorflow). # In[8]: # %%capture tf2pt.checkpoint_tf_to_state_dict_tf( input_checkpoint_tf_path=input_checkpoint_tf_path, output_csv_tf_path=output_csv_tf_path, output_state_dict_tf_path=output_state_dict_tf_path, filter_out_function=highresnet_mapping.is_not_valid, replace_string='HighRes3DNet/', ) data_frame_tf = pd.read_csv(output_csv_tf_path) state_dict_tf = torch.load(output_state_dict_tf_path) # In[9]: data_frame_tf # The weight parameters associated with each convolutional layer, denoted with `conv_/w`, are stored with shape representing the three spatial dimensions, the input channels and the output channels: $(Depth, Height, Width, Channels_{in}, Channels_{out})$. Calling the spatial dimensions *depth*, *height* and *width* does not make a lot of sense when dealing with 3D medical images, but we will keep this computer vision terminology as it is consistent with the documentation of both PyTorch and TensorFlow. # The layer names and parameter shapes are overall consistent with the figure in the HighRes3DNet paper, but there is an additional $1 \times 1 \times 1$ convolutional layer with 80 output channels, which is also in the [code](https://github.com/NifTK/NiftyNet/blob/1832a516c909b67d0d9618acbd04a7642c12efca/niftynet/network/highres3dnet.py#L93). This seems to correspond to the model with [dropout](http://jmlr.org/papers/v15/srivastava14a.html) from the paper, which achieved the highest performance, so [our implementation of the architecture](https://github.com/fepegar/highresnet/blob/f434266a51924681f95b01a0f03611fbf1148db6/highresnet/highresnet.py#L82-L97) should include this layer as well. # There are three blocks with increasing kernel [dilation](https://arxiv.org/abs/1511.07122), composed of three [residual](https://arxiv.org/abs/1512.03385) blocks each, each of which has two convolutional layers inside. That's $3 \times 3 \times 2 = 18$ layers. The other three convolutional layers are the initial convolution before the first residual block, a convolution before dropout and a convolution to expand the channels to the number of output classes. # # Apparently, *all* the convolutional layers have an associated [batch normalization](https://arxiv.org/abs/1502.03167) layer, which differs from the figure in the paper. That makes 21 convolutional layers and 21 batch normalization layers whose parameters must be transferred. # Each batch normalization layer contains four parameter groups: the *moving mean* $\mathrm{E}[x]$, the *moving variance* $\mathrm{Var}[x]$ and the affine transformation parameters $\gamma$ (scale or *weight*) and $\beta$ (shift or *bias*): # \begin{align} # y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta # \end{align} # Therefore the total number of parameter groups is $21 + 21 \times 4 = 105$. The convolutional layers don't use a bias parameter, as it is not necessary when using batch norm.
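# As a sanity check of the formula above, we can verify that a PyTorch batch normalization layer in evaluation mode applies exactly this transformation. This is only an illustrative sketch with random values, unrelated to the downloaded model:
# In[ ]:
import torch
from torch import nn

bn = nn.BatchNorm3d(num_features=8)
bn.eval()  # use the running statistics instead of the batch statistics
x = torch.randn(2, 8, 4, 4, 4)
with torch.no_grad():
    y = bn(x)
    # Reshape the per-channel parameters so they broadcast over (N, C, D, H, W)
    view = (1, -1, 1, 1, 1)
    manual = (x - bn.running_mean.view(view)) / torch.sqrt(bn.running_var.view(view) + bn.eps)
    manual = manual * bn.weight.view(view) + bn.bias.view(view)
print(torch.allclose(y, manual, atol=1e-6))  # should print True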
# ### Variables in PyTorch world # # To match the model in the [paper](https://arxiv.org/abs/1707.01992), we set the number of classes to 160 and enable the flag to add the dropout layer. # In[ ]: num_input_modalities = 1 num_classes = 160 model = HighRes3DNet(num_input_modalities, num_classes, add_dropout_layer=True) # Let's see what the variable names created by PyTorch are: # In[11]: state_dict_pt = model.state_dict() rows = [] for name, parameters in state_dict_pt.items(): if 'num_batches_tracked' in name: # filter out for clarity continue shape = ', '.join(str(n) for n in parameters.shape) row = {'name': name, 'shape': shape} rows.append(row) df_pt = pd.DataFrame.from_dict(rows) df_pt # We can see that `moving_mean` and `moving_variance` are called `running_mean` and `running_var` in PyTorch. Also, $\gamma$ and $\beta$ are called `weight` and `bias`. # The convolutional kernels have a different arrangement: $(Channels_{out}, Channels_{in}, Depth, Height, Width)$. # The names and shapes look consistent between both implementations and there are 105 entries in both lists, so we should be able to create a mapping between the TensorFlow and PyTorch variables. The function `tf2pt.tf2pt()` receives a TensorFlow-like variable and returns the corresponding PyTorch-like variable. # In[12]: for name_tf, tensor_tf in state_dict_tf.items(): shape_tf = tuple(tensor_tf.shape) print(f'{str(shape_tf):18}', name_tf) # Convert TensorFlow name to PyTorch name mapping_function = highresnet_mapping.tf2pt_name name_pt, tensor_pt = tf2pt.tf2pt(name_tf, tensor_tf, mapping_function) shape_pt = tuple(state_dict_pt[name_pt].shape) print(f'{str(shape_pt):18}', name_pt) # Sanity check if sum(shape_tf) != sum(shape_pt): raise ValueError # Add weights to PyTorch state dictionary state_dict_pt[name_pt] = tensor_pt print() torch.save(state_dict_pt, output_state_dict_pt_path) print('State dictionary saved to', output_state_dict_pt_path) # If PyTorch is happy when loading our state dict into the model, we should be on the right track 🤞... # In[13]: model.load_state_dict(state_dict_pt) # No incompatible keys. Yay! 🎉 # ### Plotting weights with PyTorch # Something great about PyTorch is that the model parameters are easily accessible. Let's plot some of them before and after training: # In[ ]: model_initial = HighRes3DNet(num_input_modalities, num_classes, add_dropout_layer=True) model_pretrained = model # By [default](https://github.com/pytorch/pytorch/blob/77353636de32a207cf0a332395f91011bc2f07fb/torch/nn/modules/conv.py#L48-L53), convolutional layers in PyTorch are initialized using [He uniform variance scaling](https://arxiv.org/abs/1502.01852). These are the probability density functions (PDFs) of the kernel parameters of each convolutional layer. Note how the domain of each function changes with the corresponding input size at that layer. # In[15]: visualization.plot_all_parameters(model_initial) # This is what the PDFs look like after training: # In[16]: visualization.plot_all_parameters(model_pretrained) # ## Testing the model # The last step is to test the PyTorch model. We will preprocess the image according to the configuration file, initialize the reader, sampler and aggregator, run the inference, and verify that results are consistent between NiftyNet and PyTorch.
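# Before setting up the full NiftyNet pipeline, a quick smoke test with a random tensor can confirm that the converted model runs and produces an output of the expected shape. This is only a sketch with a small hypothetical window size, assuming (as described in the paper) that the architecture preserves the spatial dimensions of its input:
# In[ ]:
model.eval()
with torch.no_grad():
    dummy_window = torch.rand(1, num_input_modalities, 32, 32, 32)  # (batch, channels, D, H, W)
    dummy_logits = model(dummy_window)
print(dummy_logits.shape)  # expected: torch.Size([1, 160, 32, 32, 32])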
# ### Configuration file # We need to match the configuration used during training in order to obtain consistent results. These are the relevant contents of the downloaded [configuration file](https://niftynet.readthedocs.io/en/dev/config_spec.html): # ```ini # [Modality0] # path_to_search = data/OASIS/ # filename_contains = nii # pixdim = (1.0, 1.0, 1.0) # axcodes = (R, A, S) # # [NETWORK] # name = highres3dnet # volume_padding_size = 10 # whitening = True # normalisation = True # normalise_foreground_only=True # foreground_type = mean_plus # histogram_ref_file = databrain_std_hist_models_otsu.txt # cutoff = (0.001, 0.999) # # [INFERENCE] # border = 2 # spatial_window_size = (128, 128, 128) # ``` # In[ ]: config = ConfigParser() config.read(config_path); # ### Reader # The necessary preprocessing is described in the [paper](https://arxiv.org/abs/1707.01992), [code](https://github.com/NifTK/NiftyNet/blob/61f2a8bbac1348591412c00f55d1c19b91c0367f/niftynet/application/segmentation_application.py#L95-L192) and [configuration file](https://niftynet.readthedocs.io/en/dev/config_spec.html). # # NiftyNet offers some powerful I/O tools. We will use its readers, samplers and aggregators to read, preprocess and write all the files. There are multiple [demos in the NiftyNet repository](https://github.com/NifTK/NiftyNet/tree/dev/demos/module_examples) that show the usage of these modules. # In[ ]: get_ipython().run_cell_magic('capture', '', "input_dict = dict(\n path_to_search=str(data_dir),\n filename_contains='nii',\n axcodes=('R', 'A', 'S'),\n pixdim=(1, 1, 1),\n)\ndata_parameters = {\n 'image': input_dict,\n}\nreader = ImageReader().initialise(data_parameters)\n") # In[19]: _, image_data_dict, _ = reader() original_image = image_data_dict['image'] original_image.shape # Looking at the shape of our image and knowing that the reader reoriented it into [RAS+ orientation](http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm), we can see that it represents $160$ sagittal slices of $256 \times 256$ pixels, with $1$ channel (monomodal) and $1$ time point. Let's see what it looks like: # In[20]: plot_volume(original_image, title='Original volume') # In[21]: visualization.plot_histogram(original_image, kde=False, add_labels=True, ylim=(0, 1e6)) # ### Preprocessing # We [pad the input volume](https://niftynet.readthedocs.io/en/dev/window_sizes.html#volume-padding-size) and crop the output volume to reduce the border effect introduced by the padded convolutions: # In[ ]: volume_padding_layer = PadLayer( image_name=['image'], # https://github.com/NifTK/NiftyNet/blob/61f2a8bbac1348591412c00f55d1c19b91c0367f/niftynet/layer/pad.py#L52 border=(10, 10, 10), ) # We use a masking function in order to use only the foreground voxels for normalization: # In[23]: binary_masking_func = BinaryMaskingLayer(type_str=config['NETWORK']['foreground_type']) mask = binary_masking_func(original_image) plot_volume(mask, enhance=False, title='Binary mask for preprocessing') # We apply [MRI histogram standardization](https://ieeexplore.ieee.org/document/836373) to our test image, using landmarks learned from the training dataset. We use the mean intensity of the volume as a threshold for the mask, as the authors of the method claim that this usually gives good results.
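# Conceptually, this kind of foreground mask simply keeps the voxels whose intensity is above the mean of the volume. The following sketch only illustrates that idea with NumPy; it is not NiftyNet's exact `mean_plus` implementation:
# In[ ]:
import numpy as np

data = np.asarray(original_image).squeeze()
conceptual_mask = data > data.mean()  # True for (approximately) foreground voxels
print(f'Foreground fraction: {conceptual_mask.mean():.2f}')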
# In[ ]: hist_norm = HistogramNormalisationLayer( image_name='image', modalities=['Modality0'], model_filename=str(histogram_landmarks_path), binary_masking_func=binary_masking_func, cutoff=(0.001, 0.999), name='hist_norm_layer', ) # Finally, we force our image foreground to have zero mean and unit variance: # In[ ]: whitening = MeanVarNormalisationLayer( image_name='image', binary_masking_func=binary_masking_func) # Here is our preprocessed image: # In[26]: get_ipython().run_cell_magic('capture', '--no-display', "preprocessing_layers = [\n volume_padding_layer,\n hist_norm,\n whitening,\n]\nreader = ImageReader().initialise(data_parameters)\nreader.add_preprocessing_layers(preprocessing_layers)\n_, image_data_dict, _ = reader()\npreprocessed_image = image_data_dict['image']\nplot_volume(preprocessed_image, title='Preprocessed image')\n") # Note the small difference in intensities due to histogram standardization and the 10-voxel zero-padding. # # We can clearly see the effect of the whitening layer on the histogram: # In[27]: visualization.plot_histogram(preprocessed_image, kde=False, add_labels=True, ylim=(0, 1e6)) # ### Sampler and aggregator # As the whole image does not fit in the memory of most GPUs, we need to use a [patch-based](https://niftynet.readthedocs.io/en/dev/window_sizes.html) approach. # # *(Diagram: windows extracted by the grid sampler shown in blue, windows reconstructed by the aggregator in red.)* # We will use NiftyNet's grid sampler to get all windows from the volume (blue in the diagram above) and a grid samples aggregator (red) to reconstruct the output image from the inferred windows. If you have any memory issues, try reducing the window size. # # The [window border](https://niftynet.readthedocs.io/en/dev/window_sizes.html#border) is needed to reduce the border effect in a dense prediction. # In[ ]: get_ipython().run_cell_magic('capture', '', "window_size = 128\nwindow_size = 3 * (window_size, )\nwindow_border = 2, 2, 2\nwindow_size_dict = {'image': window_size}\nbatch_size = 1\n\nsampler = GridSampler(\n reader,\n window_size_dict,\n window_border=window_border,\n)\n\nprediction_pt_dir = tempdir / 'prediction'\nprediction_pt_dir.mkdir(exist_ok=True)\naggregator = GridSamplesAggregator(\n image_reader=reader,\n window_border=window_border,\n output_path=prediction_pt_dir,\n)\n") # ### Run inference # Now, the most important part: running the parcellation! We will iterate over the windows provided by the grid sampler, pass them through the network and aggregate them into the output volume. With the default parameters, the inference will run for 27 iterations, which might take a couple of minutes: # In[29]: device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' model.to(device) model.eval() for batch_index, batch_dict in enumerate(sampler()): print(f'Running inference iteration {batch_index}...') input_tensor = tf2pt.niftynet_batch_to_torch_tensor(batch_dict).to(device) with torch.no_grad(): logits = model(input_tensor) labels = tf2pt.torch_logits_to_niftynet_labels(logits) window_dict = dict(window=labels) aggregator.decode_batch(window_dict, batch_dict['image_location']) # Release GPU memory del model del model_pretrained del input_tensor del logits torch.cuda.empty_cache()
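# The conversion from logits to labels performed above by `tf2pt.torch_logits_to_niftynet_labels` essentially boils down to an argmax over the channel dimension. Here is a conceptual sketch only; the actual helper also moves the result off the GPU and rearranges the axes expected by the NiftyNet aggregator:
# In[ ]:
def logits_to_labels(logits: torch.Tensor) -> torch.Tensor:
    """Conceptual version: pick the most likely class for each voxel."""
    # logits has shape (batch, num_classes, D, H, W)
    return logits.argmax(dim=1)  # shape (batch, D, H, W)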
# Let's [run the inference using NiftyNet](https://github.com/NifTK/NiftyNetModelZoo/tree/5-reorganising-with-lfs/highres3dnet_brain_parcellation#generating-segmentations-for-example-data) as well, so that we can compare both results. As before, this step might take a couple of minutes: # In[30]: get_ipython().run_line_magic('run', 'NiftyNet/net_segment.py inference -c ~/niftynet/extensions/highres3dnet_brain_parcellation/highres3dnet_config_eval.ini') # ### Check results # In[ ]: get_ipython().run_cell_magic('capture', '', 'input_image = utils.get_first_array(data_dir)\nlabels_nn = utils.get_first_array(models_dir).astype(np.uint16)\nlabels_pt = utils.get_first_array(prediction_pt_dir).astype(np.uint16)\n') # #### Quantitatively # In[32]: difference = labels_nn != labels_pt num_different_voxels = np.count_nonzero(difference) print('Number of different voxels:', num_different_voxels) # Success! ✨ Both parcellations are exactly the same. # #### Qualitatively # In[33]: plot_volume( labels_nn, enhance=False, colors_path=color_table_path, title='Parcellation inferred by NiftyNet', ) # In[34]: plot_volume( labels_pt, enhance=False, colors_path=color_table_path, title='Parcellation inferred by PyTorch', ) # ## Future work # The next thing we might want to do is start from the pretrained model for a [transfer learning](http://cs231n.github.io/transfer-learning/) application. # # For example, one could use the model as a feature extractor, removing the classifier (last convolutional layer). The features extracted from an image could be used for classification tasks, such as classifying Alzheimer's disease patients vs. controls. We could also add and train a 4-output-channel classifier for brain tissue segmentation (cerebrospinal fluid, white matter, gray matter and background). # # We might also want to fine-tune the model so that it performs better on a smaller dataset that is different from the one used for training. In that case, if we want to use the same optimizer (Adam), we should also retrieve the optimizer-related parameters from the checkpoint files and use them for training. # ## Conclusion # In this tutorial we have shown how to combine features from two deep learning frameworks, NiftyNet and PyTorch. # # We ported a model for brain parcellation from NiftyNet to PyTorch and ran inference using a PyTorch model and NiftyNet's I/O capabilities. # ## Acknowledgments # I want to thank Pritesh Mehta, Tom Varsavsky, Zach Eaton-Rosen and Oeslle Lucena for their feedback. # # If you want to cite this tutorial you can use this [Zenodo DOI](https://doi.org/10.5281/zenodo.3352316): # # [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3352316.svg)](https://doi.org/10.5281/zenodo.3352316)