Build your understanding of PyTorch's ConvTranspose1d layer using interactive visualisations
Sometimes we want to see the inputs and outputs of PyTorch layers to build an intuition for what they do. If I've read the docs and put a few tensors through the layer while checking the input and output shapes, that's generally enough.
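As a minimal sketch of that "put a few tensors through it" check, the snippet below passes a random tensor through a ConvTranspose1d and prints the shapes. The channel counts, kernel size and input length are arbitrary illustrative choices. It also shows one detail that is easy to miss: ConvTranspose1d stores its weight as (in_channels, out_channels, kernel_size), the reverse of Conv1d's (out_channels, in_channels, kernel_size).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # make the random weights and input reproducible

# One sequence with 1 channel and length 5: (batch, channels, length)
x = torch.randn(1, 1, 5)

# Illustrative parameters only; with stride=1 the length grows by kernel_size - 1
deconv = nn.ConvTranspose1d(in_channels=1, out_channels=2, kernel_size=3)
y = deconv(x)

print("input shape: ", tuple(x.shape))  # (1, 1, 5)
print("output shape:", tuple(y.shape))  # (1, 2, 7)

# ConvTranspose1d weight: (in_channels, out_channels, kernel_size)
print("ConvTranspose1d weight:", tuple(deconv.weight.shape))  # (1, 2, 3)
# Conv1d weight for comparison: (out_channels, in_channels, kernel_size)
print("Conv1d weight:", tuple(nn.Conv1d(1, 2, 3).weight.shape))  # (2, 1, 3)
```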
But sometimes there are weird parameters that I can't get my head around, or I just want to see the layer working, so building interactive widgets helps me grow my understanding.
So in this post I'll show you how I built an interactive widget to explore PyTorch's ConvTranspose1d, while explaining a bit about the layer itself. We'll use Anaconda's HoloViz tools (HoloViews, Panel and Bokeh) for the plotting and interactivity.
The end goal is an interactive plot for adjusting ConvTranspose1d parameters and seeing the output, like the one in this tweet.
Before learning about Transposed Convolutions, you're best off learning about Convolutions first. CS231n is a great resource for learning about them.
As you may know, Convolutions are often used in neural networks to efficiently reduce the dimensions of the input. In image classification tasks, for example, they are used to efficiently reduce an input image to a single class score.
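To make the "reduce the dimensions" point concrete, here is a small sketch using strided Conv1d layers. The kernel size, stride and padding are illustrative choices (not taken from the original post) picked so that each layer exactly halves the sequence length.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 16)  # (batch, channels, length)

# Each strided convolution steps 2 positions along the input per output,
# so with kernel_size=4 and padding=1 the length is halved at every layer.
down = nn.Sequential(
    nn.Conv1d(1, 4, kernel_size=4, stride=2, padding=1),  # 16 -> 8
    nn.Conv1d(4, 8, kernel_size=4, stride=2, padding=1),  # 8 -> 4
)

h = down(x)
print(tuple(x.shape), "->", tuple(h.shape))  # (1, 1, 16) -> (1, 8, 4)
```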
Transposed Convolutions are useful when you want to grow your network in a certain dimension. For example, say you have an image segmentation task in which you want a class prediction per pixel: you can use strided Convolutions to reduce the dimensions and then grow them back to their original size with Transposed Convolutions. This is done in U-Net style architectures.
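The shrink-then-grow pattern can be sketched in 1D with a hypothetical, minimal encoder-decoder. This is not a real U-Net (it has no skip connections), and the layer sizes and `num_classes` are illustrative assumptions; it only shows strided Conv1d layers halving the length and ConvTranspose1d layers doubling it back, so every input position ends up with one score per class.

```python
import torch
import torch.nn as nn

num_classes = 3  # hypothetical number of classes per position

# Hypothetical encoder-decoder: no skip connections, illustrative sizes only
model = nn.Sequential(
    nn.Conv1d(1, 8, kernel_size=4, stride=2, padding=1),    # length 16 -> 8
    nn.ReLU(),
    nn.Conv1d(8, 16, kernel_size=4, stride=2, padding=1),   # length 8 -> 4
    nn.ReLU(),
    nn.ConvTranspose1d(16, 8, kernel_size=4, stride=2, padding=1),            # 4 -> 8
    nn.ReLU(),
    nn.ConvTranspose1d(8, num_classes, kernel_size=4, stride=2, padding=1),   # 8 -> 16
)

x = torch.randn(2, 1, 16)  # 2 sequences, 1 channel, length 16
scores = model(x)          # one score per class for every input position
print(tuple(scores.shape))  # (2, 3, 16)
```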
PyTorch has implemented ConvTranspose1d such that if it has the same input parameters as Conv1d and you pass a tensor through both, the output tensor will be the same shape as the input tensor (provided you set
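#collapse-hide
# PyTorch, for the layers we want to explore
import torch
import torch.nn as nn
# Panel, for the widgets and interactivity
from panel.interact import interact
from panel import widgets
import panel as pn
from IPython.display import display
# HoloViews, for plotting, plus NumPy for array handling
import holoviews as hv
from holoviews import opts
import numpy as np
# Render HoloViews plots with the Bokeh backend and hide the HoloViz logo
hv.extension('bokeh', logo=False)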
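One way to check the Conv1d/ConvTranspose1d shape relationship is to compare both layers against the output-length formulas given in the PyTorch documentation. The sketch below uses illustrative parameters (kernel_size=3, stride=2), and the helper functions `conv1d_len` and `conv_transpose1d_len` are hypothetical names written just for this example. The docs also note that when stride > 1, Conv1d maps several input lengths to the same output length, so ConvTranspose1d alone cannot know which one to return; its `output_padding` argument adds extra length to one side of the output to pick a different candidate.

```python
import torch
import torch.nn as nn

def conv1d_len(l_in, k, s=1, p=0, d=1):
    # Hypothetical helper: Conv1d output length, per the PyTorch docs formula
    return (l_in + 2 * p - d * (k - 1) - 1) // s + 1

def conv_transpose1d_len(l_in, k, s=1, p=0, d=1, output_padding=0):
    # Hypothetical helper: ConvTranspose1d output length, per the PyTorch docs formula
    return (l_in - 1) * s - 2 * p + d * (k - 1) + output_padding + 1

params = dict(kernel_size=3, stride=2)  # identical parameters for both layers
conv = nn.Conv1d(1, 1, **params)
deconv = nn.ConvTranspose1d(1, 1, **params)

for l_in in (7, 8):
    h = conv(torch.randn(1, 1, l_in))
    y = deconv(h)
    # The layers agree with the formulas
    assert h.shape[-1] == conv1d_len(l_in, k=3, s=2)
    assert y.shape[-1] == conv_transpose1d_len(h.shape[-1], k=3, s=2)
    print(f"{l_in} -> Conv1d -> {h.shape[-1]} -> ConvTranspose1d -> {y.shape[-1]}")
# 7 -> Conv1d -> 3 -> ConvTranspose1d -> 7
# 8 -> Conv1d -> 3 -> ConvTranspose1d -> 7

# output_padding (must be smaller than stride or dilation) selects the longer candidate
deconv_pad = nn.ConvTranspose1d(1, 1, output_padding=1, **params)
print(deconv_pad(torch.randn(1, 1, 3)).shape[-1])  # 8
```

Both length-7 and length-8 inputs shrink to 3, and the plain transposed layer only recovers 7, which is why the round trip only restores the original shape for some input lengths unless `output_padding` is set.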