#hide
! [ -e /content ] && pip install -Uqq fastbook
import fastbook
fastbook.setup_book()
#hide
from fastai.vision.all import *
from fastbook import *
matplotlib.rc('image', cmap='Greys')
In <<chapter_mnist_basics>> we learned how to create a neural network that recognizes images. We were able to achieve a bit over 98% accuracy at distinguishing 3s from 7s—but we also saw that fastai's built-in classes were able to get close to 100%. Let's start trying to close the gap.
In this chapter, we will begin by digging into what convolutions are and building a CNN from scratch. We will then study a range of techniques to improve training stability and learn all the tweaks the library usually applies for us to get great results.
One of the most powerful tools that machine learning practitioners have at their disposal is feature engineering. A feature is a transformation of the data which is designed to make it easier to model. For instance, the add_datepart
function that we used for our tabular dataset preprocessing in <<chapter_tabular>> added date features to the Bulldozers dataset. What kinds of features might we be able to create from images?
jargon: Feature engineering: Creating new transformations of the input data in order to make it easier to model.
In the context of an image, a feature is a visually distinctive attribute. For example, the number 7 is characterized by a horizontal edge near the top of the digit, and a top-right to bottom-left diagonal edge underneath that. On the other hand, the number 3 is characterized by a diagonal edge in one direction at the top left and bottom right of the digit, the opposite diagonal at the bottom left and top right, horizontal edges at the middle, top, and bottom, and so forth. So what if we could extract information about where the edges occur in each image, and then use that information as our features, instead of raw pixels?
It turns out that finding the edges in an image is a very common task in computer vision, and is surprisingly straightforward. To do it, we use something called a convolution. A convolution requires nothing more than multiplication and addition—two operations that are responsible for the vast majority of work that we will see in every single deep learning model in this book!
A convolution applies a kernel across an image. A kernel is a little matrix, such as the 3×3 matrix in the top right of <<basic_conv>>.
The 7×7 grid to the left is the image we're going to apply the kernel to. The convolution operation multiplies each element of the kernel by each element of a 3×3 block of the image. The results of these multiplications are then added together. The diagram in <<basic_conv>> shows an example of applying a kernel to a single location in the image, the 3×3 block around cell 18.
Let's do this with code. First, we create a little 3×3 matrix like so:
top_edge = tensor([[-1,-1,-1],
[ 0, 0, 0],
[ 1, 1, 1]]).float()
We're going to call this our kernel (because that's what fancy computer vision researchers call these). And we'll need an image, of course:
path = untar_data(URLs.MNIST_SAMPLE)
#hide
Path.BASE_PATH = path
im3 = Image.open(path/'train'/'3'/'12.png')
show_image(im3);
Now we're going to take the top 3×3-pixel square of our image, and multiply each of those values by each item in our kernel. Then we'll add them up, like so:
im3_t = tensor(im3)
im3_t[0:3,0:3] * top_edge
tensor([[-0., -0., -0.], [0., 0., 0.], [0., 0., 0.]])
(im3_t[0:3,0:3] * top_edge).sum()
tensor(0.)
Not very interesting so far—all the pixels in the top-left corner are white. But let's pick a couple of more interesting spots:
#hide_output
df = pd.DataFrame(im3_t[:10,:20])
df.style.set_properties(**{'font-size':'6pt'}).background_gradient('Greys')
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
5 | 0 | 0 | 0 | 12 | 99 | 91 | 142 | 155 | 246 | 182 | 155 | 155 | 155 | 155 | 131 | 52 | 0 | 0 | 0 | 0 |
6 | 0 | 0 | 0 | 138 | 254 | 254 | 254 | 254 | 254 | 254 | 254 | 254 | 254 | 254 | 254 | 252 | 210 | 122 | 33 | 0 |
7 | 0 | 0 | 0 | 220 | 254 | 254 | 254 | 235 | 189 | 189 | 189 | 189 | 150 | 189 | 205 | 254 | 254 | 254 | 75 | 0 |
8 | 0 | 0 | 0 | 35 | 74 | 35 | 35 | 25 | 0 | 0 | 0 | 0 | 0 | 0 | 13 | 224 | 254 | 254 | 153 | 0 |
9 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 90 | 254 | 254 | 247 | 53 | 0 |
There's a top edge at cell 5,7. Let's repeat our calculation there:
(im3_t[4:7,6:9] * top_edge).sum()
tensor(762.)
There's a right edge at cell 8,18. What does that give us?:
(im3_t[7:10,17:20] * top_edge).sum()
tensor(-29.)
As you can see, this little calculation is returning a high number where the 3×3-pixel square represents a top edge (i.e., where there are low values at the top of the square, and high values immediately underneath). That's because the -1 values in our kernel have little impact in that case, but the 1 values have a lot.
Let's look a tiny bit at the math. The filter will take any window of size 3×3 in our images, and if we name the pixel values like this:
$$\begin{matrix} a1 & a2 & a3 \\ a4 & a5 & a6 \\ a7 & a8 & a9 \end{matrix}$$

it will return $-a1-a2-a3+a7+a8+a9$. If we are in a part of the image where $a1$, $a2$, and $a3$ add up to the same as $a7$, $a8$, and $a9$, then the terms will cancel each other out and we will get 0. However, if $a7$ is greater than $a1$, $a8$ is greater than $a2$, and $a9$ is greater than $a3$, we will get a bigger number as a result. So this filter detects horizontal edges—more precisely, edges where we go from bright parts of the image at the top to darker parts at the bottom.
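To make that concrete, here is a minimal numeric sketch (the two windows below are made-up values, not taken from our MNIST image): a uniform window cancels out to 0, while a window with low values on top and high values underneath produces a large response.
flat = tensor([[10.,10,10],
               [10.,10,10],
               [10.,10,10]])   # uniform region: the -1 and 1 terms cancel
edge = tensor([[ 0., 0, 0],
               [ 0., 0, 0],
               [10.,10,10]])   # low values on top, high values underneath
(flat * top_edge).sum(), (edge * top_edge).sum()
# -> (tensor(0.), tensor(30.))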
Changing our filter to have the row of 1s at the top and the -1s at the bottom would detect horizontal edges that go from dark to light. Putting the 1s and -1s in columns versus rows would give us filters that detect vertical edges. Each set of weights will produce a different kind of outcome.
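For instance, here is a sketch of the two variants just described (the names bottom_edge and right_edge are only for illustration; we use just left_edge and the diagonal kernels later in this chapter):
bottom_edge = tensor([[ 1, 1, 1],
                      [ 0, 0, 0],
                      [-1,-1,-1]]).float()   # 1s on top, -1s at the bottom: dark-to-light horizontal edges
right_edge  = tensor([[ 0, 1,-1],
                      [ 0, 1,-1],
                      [ 0, 1,-1]]).float()   # 1s and -1s in columns instead of rows: vertical edges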
Let's create a function to do this for one location, and check it matches our result from before:
def apply_kernel(row, col, kernel):
return (im3_t[row-1:row+2,col-1:col+2] * kernel).sum()
apply_kernel(5,7,top_edge)
tensor(762.)
But note that we can't apply it to the corner (e.g., location 0,0), since there isn't a complete 3×3 square there.
We can map apply_kernel() across the coordinate grid. That is, we'll be taking our 3×3 kernel, and applying it to each 3×3 section of our image. For instance, <<nopad_conv>> shows the positions a 3×3 kernel can be applied to in the first row of a 5×5 image.
To get a grid of coordinates we can use a nested list comprehension, like so:
[[(i,j) for j in range(1,5)] for i in range(1,5)]
[[(1, 1), (1, 2), (1, 3), (1, 4)], [(2, 1), (2, 2), (2, 3), (2, 4)], [(3, 1), (3, 2), (3, 3), (3, 4)], [(4, 1), (4, 2), (4, 3), (4, 4)]]
note: Nested List Comprehensions: Nested list comprehensions are used a lot in Python, so if you haven't seen them before, take a few minutes to make sure you understand what's happening here, and experiment with writing your own nested list comprehensions.
Here's the result of applying our kernel over a coordinate grid:
rng = range(1,27)
top_edge3 = tensor([[apply_kernel(i,j,top_edge) for j in rng] for i in rng])
show_image(top_edge3);
Looking good! Our top edges are black, and bottom edges are white (since they are the opposite of top edges). Now that our image contains negative numbers too, matplotlib has automatically changed our colors so that white is the smallest number in the image, black the highest, and zeros appear as gray.
We can try the same thing for left edges:
left_edge = tensor([[-1,1,0],
[-1,1,0],
[-1,1,0]]).float()
left_edge3 = tensor([[apply_kernel(i,j,left_edge) for j in rng] for i in rng])
show_image(left_edge3);
As we mentioned before, a convolution is the operation of applying such a kernel over a grid in this way. In the paper "A Guide to Convolution Arithmetic for Deep Learning" there are many great diagrams showing how image kernels can be applied. Here's an example from the paper showing (at the bottom) a light blue 4×4 image, with a dark blue 3×3 kernel being applied, creating a 2×2 green output activation map at the top.
Look at the shape of the result. If the original image has a height of h and a width of w, how many 3×3 windows can we find? As you can see from the example, there are h-2 by w-2 windows, so the image we get back has a height of h-2 and a width of w-2.
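We can check this on the activation map we computed earlier: im3 is 28×28, so the result should be 26×26.
top_edge3.shape
# -> torch.Size([26, 26])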
We won't implement this convolution function from scratch, but use PyTorch's implementation instead (it is way faster than anything we could do in Python).
Convolution is such an important and widely used operation that PyTorch has it built in. It's called F.conv2d (recall that F is a fastai import from torch.nn.functional, as recommended by PyTorch). The PyTorch docs tell us that it includes these parameters:
input: input tensor of shape (minibatch, in_channels, iH, iW)
weight: filters of shape (out_channels, in_channels, kH, kW)
Here iH,iW is the height and width of the image (i.e., 28,28), and kH,kW is the height and width of our kernel (3,3). But apparently PyTorch is expecting rank-4 tensors for both these arguments, whereas currently we only have rank-2 tensors (i.e., matrices, or arrays with two axes).
The reason for these extra axes is that PyTorch has a few tricks up its sleeve. The first trick is that PyTorch can apply a convolution to multiple images at the same time. That means we can call it on every item in a batch at once!
The second trick is that PyTorch can apply multiple kernels at the same time. So let's create the diagonal-edge kernels too, and then stack all four of our edge kernels into a single tensor:
diag1_edge = tensor([[ 0,-1, 1],
[-1, 1, 0],
[ 1, 0, 0]]).float()
diag2_edge = tensor([[ 1,-1, 0],
[ 0, 1,-1],
[ 0, 0, 1]]).float()
edge_kernels = torch.stack([left_edge, top_edge, diag1_edge, diag2_edge])
edge_kernels.shape
torch.Size([4, 3, 3])
To test this, we'll need a DataLoader and a sample mini-batch. Let's use the data block API:
mnist = DataBlock((ImageBlock(cls=PILImageBW), CategoryBlock),
get_items=get_image_files,
splitter=GrandparentSplitter(),
get_y=parent_label)
dls = mnist.dataloaders(path)
xb,yb = first(dls.valid)
xb.shape
torch.Size([64, 1, 28, 28])
By default, fastai puts data on the GPU when using data blocks. Let's move it to the CPU for our examples:
xb,yb = to_cpu(xb),to_cpu(yb)
One batch contains 64 images, each of 1 channel, with 28×28 pixels. F.conv2d can handle multichannel (i.e., color) images too. A channel is a single basic color in an image—for regular full-color images there are three channels, red, green, and blue. PyTorch represents an image as a rank-3 tensor, with dimensions [channels, rows, columns].
We'll see how to handle more than one channel later in this chapter. Kernels passed to F.conv2d need to be rank-4 tensors: [features_out, channels_in, rows, columns], matching the (out_channels, in_channels, kH, kW) shape shown in the docs. edge_kernels is currently missing one of these axes. We need to tell PyTorch that the number of input channels in the kernel is one, which we can do by inserting an axis of size one (this is known as a unit axis) in the second position, where the PyTorch docs show in_channels is expected. To insert a unit axis into a tensor, we use the unsqueeze method:
edge_kernels.shape,edge_kernels.unsqueeze(1).shape
(torch.Size([4, 3, 3]), torch.Size([4, 1, 3, 3]))
This is now the correct shape for edge_kernels. Let's pass this all to conv2d:
edge_kernels = edge_kernels.unsqueeze(1)
batch_features = F.conv2d(xb, edge_kernels)
batch_features.shape
torch.Size([64, 4, 26, 26])
The output shape shows we gave 64 images in the mini-batch, 4 kernels, and 26×26 edge maps (we started with 28×28 images, but lost one pixel from each side as discussed earlier). We can see we get the same results as when we did this manually:
show_image(batch_features[0,0]);
The most important trick that PyTorch has up its sleeve is that it can use the GPU to do all this work in parallel—that is, applying multiple kernels, to multiple images, across multiple channels. Doing lots of work in parallel is critical to getting GPUs to work efficiently; if we did each of these operations one at a time, we'd often run hundreds of times slower (and if we used our manual convolution loop from the previous section, we'd be millions of times slower!). Therefore, to become a strong deep learning practitioner, one skill to practice is giving your GPU plenty of work to do at a time.
It would be nice to not lose those two pixels on each axis. The way we do that is to add padding, which is simply additional pixels added around the outside of our image. Most commonly, pixels of zeros are added.
With appropriate padding, we can ensure that the output activation map is the same size as the original image, which can make things a lot simpler when we construct our architectures. <<pad_conv>> shows how adding padding allows us to apply the kernels in the image corners.
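As a quick check of this (reusing the xb and edge_kernels from the previous section), passing padding=1 to F.conv2d keeps the 28×28 size instead of shrinking it to 26×26:
F.conv2d(xb, edge_kernels, padding=1).shape
# -> torch.Size([64, 4, 28, 28])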
With a 5×5 input, 4×4 kernel, and 2 pixels of padding, we end up with a 6×6 activation map, as we can see in <<four_by_five_conv>>.
If we use a kernel of size ks by ks (with ks an odd number), the necessary padding on each side to keep the same shape is ks//2. An even number for ks would require a different amount of padding on the top/bottom and left/right, but in practice we almost never use an even filter size.
So far, when we have applied the kernel to the grid, we have moved it one pixel over at a time. But we can jump further; for instance, we could move over two pixels after each kernel application, as in <<three_by_five_conv>>. This is known as a stride-2 convolution. The most common kernel size in practice is 3×3, and the most common padding is 1. As you'll see, stride-2 convolutions are useful for decreasing the size of our outputs, and stride-1 convolutions are useful for adding layers without changing the output size.
In an image of size h by w, using a padding of 1 and a stride of 2 will give us a result of size (h+1)//2 by (w+1)//2. The general formula for each dimension is (n + 2*pad - ks)//stride + 1, where pad is the padding, ks is the size of our kernel, and stride is the stride.
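Here's a small helper (an illustrative sketch, not something provided by fastai) that implements this formula; applying it with the defaults we'll use later (3×3 kernels, padding 1, stride 2) traces the sizes 28→14→7→4→2→1:
def conv_out_size(n, ks=3, stride=2, pad=1):
    "Output size along one dimension, per the formula above"
    return (n + 2*pad - ks)//stride + 1

[conv_out_size(n) for n in (28,14,7,4,2)]
# -> [14, 7, 4, 2, 1]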
Let's now take a look at how the pixel values of the result of our convolutions are computed.
To explain the math behind convolutions, fast.ai student Matt Kleinsmith came up with the very clever idea of showing CNNs from different viewpoints. In fact, it's so clever, and so helpful, we're going to show it here too!
Here's our 3×3 pixel image, with each pixel labeled with a letter:
And here's our kernel, with each weight labeled with a Greek letter:
Since the filter fits in the image four times, we have four results:
<<apply_kernel>> shows how we applied the kernel to each section of the image to yield each result.
The equation view is in <<eq_view>>.
Notice that the bias term, b, is the same for each section of the image. You can consider the bias as part of the filter, just like the weights (α, β, γ, δ) are part of the filter.
Here's an interesting insight—a convolution can be represented as a special kind of matrix multiplication, as illustrated in <<conv_matmul>>. The weight matrix is just like the ones from traditional neural networks. However, this weight matrix has two special properties: the zeros are untrainable, so they stay zero throughout the optimization process; and many of the remaining weights are shared, so while they are trainable they are constrained to remain equal to one another. The zeros correspond to the pixels that the filter can't touch, and each row of the weight matrix corresponds to one application of the filter.
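To make this concrete, here is a minimal sketch (the numeric values and the name W are made up for illustration) that builds such a weight matrix for the 3×3 image and 2×2 kernel above, and checks that multiplying by it gives the same four results as F.conv2d:
img    = tensor([[1.,2,3],[4,5,6],[7,8,9]])   # stand-in values for the pixels labeled with letters
kernel = tensor([[10.,20],[30,40]])           # stand-in values for alpha, beta, gamma, delta
a,b,g,d = kernel.flatten().tolist()

# One row per application of the filter; zeros where the filter can't touch,
# and the same four shared weights repeated in every row.
W = tensor([[a, b, 0, g, d, 0, 0, 0, 0],
            [0, a, b, 0, g, d, 0, 0, 0],
            [0, 0, 0, a, b, 0, g, d, 0],
            [0, 0, 0, 0, a, b, 0, g, d]])

as_matmul = (W @ img.flatten()).view(2,2)
as_conv   = F.conv2d(img[None,None], kernel[None,None])[0,0]
as_matmul, as_conv   # the two results are identical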
Now that we understand what a convolution is, let's use them to build a neural net.
There is no reason to believe that some particular edge filters are the most useful kernels for image recognition. Furthermore, we've seen that in later layers convolutional kernels become complex transformations of features from lower levels, but we don't have a good idea of how to manually construct these.
Instead, it would be best to learn the values of the kernels. We already know how to do this—SGD! In effect, the model will learn the features that are useful for classification.
When we use convolutions instead of (or in addition to) regular linear layers we create a convolutional neural network (CNN).
Let's go back to the basic neural network we had in <<chapter_mnist_basics>>. It was defined like this:
simple_net = nn.Sequential(
nn.Linear(28*28,30),
nn.ReLU(),
nn.Linear(30,1)
)
We can view a model's definition:
simple_net
Sequential(
  (0): Linear(in_features=784, out_features=30, bias=True)
  (1): ReLU()
  (2): Linear(in_features=30, out_features=1, bias=True)
)
We now want to create a similar architecture to this linear model, but using convolutional layers instead of linear. nn.Conv2d is the module equivalent of F.conv2d. It's more convenient than F.conv2d when creating an architecture, because it creates the weight matrix for us automatically when we instantiate it.
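For example (a quick sketch; the name conv_layer is just for illustration), instantiating the module is enough to get a weight tensor of shape (out_channels, in_channels, kH, kW) and a bias with one element per output channel:
conv_layer = nn.Conv2d(1, 30, kernel_size=3, padding=1)
conv_layer.weight.shape, conv_layer.bias.shape
# -> (torch.Size([30, 1, 3, 3]), torch.Size([30]))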
Here's a possible architecture:
broken_cnn = sequential(
nn.Conv2d(1,30, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(30,1, kernel_size=3, padding=1)
)
One thing to note here is that we didn't need to specify 28×28 as the input size. That's because a linear layer needs a weight in the weight matrix for every pixel, so it needs to know how many pixels there are, but a convolution is applied over each pixel automatically. The weights only depend on the number of input and output channels and the kernel size, as we saw in the previous section.
Think about what the output shape is going to be, then let's try it and see:
broken_cnn(xb).shape
torch.Size([64, 1, 28, 28])
This is not something we can use to do classification, since we need a single output activation per image, not a 28×28 map of activations. One way to deal with this is to use enough stride-2 convolutions such that the final layer is size 1. That is, after one stride-2 convolution the size will be 14×14, after two it will be 7×7, then 4×4, 2×2, and finally size 1.
Let's try that now. First, we'll define a function with the basic parameters we'll use in each convolution:
def conv(ni, nf, ks=3, act=True):
res = nn.Conv2d(ni, nf, stride=2, kernel_size=ks, padding=ks//2)
if act: res = nn.Sequential(res, nn.ReLU())
return res
important: Refactoring: Refactoring parts of your neural networks like this makes it much less likely you'll get errors due to inconsistencies in your architectures, and makes it more obvious to the reader which parts of your layers are actually changing.
When we use a stride-2 convolution, we often increase the number of features at the same time. This is because we're decreasing the number of activations in the activation map by a factor of 4; we don't want to decrease the capacity of a layer by too much at a time.
jargon: channels and features: These two terms are largely used interchangeably, and refer to the size of the second axis of a weight matrix, which is the number of activations per grid cell after a convolution. Features is never used to refer to the input data, but channels can refer to either the input data (generally channels are colors) or activations inside the network.
Here is how we can build a simple CNN:
simple_cnn = sequential(
conv(1 ,4), #14x14
conv(4 ,8), #7x7
conv(8 ,16), #4x4
conv(16,32), #2x2
conv(32,2, act=False), #1x1
Flatten(),
)
j: I like to add comments like the ones here after each convolution to show how large the activation map will be after each layer. These comments assume that the input size is 28*28.
Now the network outputs two activations, which map to the two possible levels in our labels:
simple_cnn(xb).shape
torch.Size([64, 2])
We can now create our Learner:
learn = Learner(dls, simple_cnn, loss_func=F.cross_entropy, metrics=accuracy)
To see exactly what's going on in the model, we can use summary:
learn.summary()
Sequential (Input shape: ['64 x 1 x 28 x 28'])
================================================================
Layer (type)         Output Shape         Param #    Trainable 
================================================================
Conv2d               64 x 4 x 14 x 14     40         True      
________________________________________________________________
ReLU                 64 x 4 x 14 x 14     0          False     
________________________________________________________________
Conv2d               64 x 8 x 7 x 7       296        True      
________________________________________________________________
ReLU                 64 x 8 x 7 x 7       0          False     
________________________________________________________________
Conv2d               64 x 16 x 4 x 4      1,168      True      
________________________________________________________________
ReLU                 64 x 16 x 4 x 4      0          False     
________________________________________________________________
Conv2d               64 x 32 x 2 x 2      4,640      True      
________________________________________________________________
ReLU                 64 x 32 x 2 x 2      0          False     
________________________________________________________________
Conv2d               64 x 2 x 1 x 1       578        True      
________________________________________________________________
Flatten              64 x 2               0          False     
________________________________________________________________

Total params: 6,722
Total trainable params: 6,722
Total non-trainable params: 0

Optimizer used: <function Adam at 0x7fbc9c258cb0>
Loss function: <function cross_entropy at 0x7fbca9ba0170>

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
Note that the output of the final Conv2d layer is 64x2x1x1. We need to remove those extra 1x1 axes; that's what Flatten does. It's basically the same as PyTorch's squeeze method, but as a module.
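As a quick illustration of that equivalence (using a random tensor of the same shape as our final convolutional output):
t = torch.randn(64, 2, 1, 1)
Flatten()(t).shape, t.squeeze().shape
# -> (torch.Size([64, 2]), torch.Size([64, 2]))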
Let's see if this trains! Since this is a deeper network than we've built from scratch before, we'll use a lower learning rate and more epochs:
learn.fit_one_cycle(2, 0.01)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.072684 | 0.045110 | 0.990186 | 00:05 |
1 | 0.022580 | 0.030775 | 0.990186 | 00:05 |
Success! It's getting closer to the resnet18 result we had, although it's not quite there yet, it's taking more epochs, and we need to use a lower learning rate. We still have a few more tricks to learn, but we're getting closer and closer to being able to create a modern CNN from scratch.
We can see from the summary that we have an input of size 64x1x28x28. The axes are batch,channel,height,width. This is often represented as NCHW (where N refers to batch size). TensorFlow, on the other hand, uses NHWC axis order. The first layer is:
m = learn.model[0]
m
Sequential(
  (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (1): ReLU()
)
So we have 1 input channel, 4 output channels, and a 3×3 kernel. Let's check the weights of the first convolution:
m[0].weight.shape
torch.Size([4, 1, 3, 3])
The summary shows we have 40 parameters, and 4*1*3*3 is 36. What are the other four parameters? Let's see what the bias contains:
m[0].bias.shape
torch.Size([4])
We can now use this information to clarify our statement in the previous section: "When we use a stride-2 convolution, we often increase the number of features because we're decreasing the number of activations in the activation map by a factor of 4; we don't want to decrease the capacity of a layer by too much at a time."
There is one bias for each channel. (Sometimes channels are called features or filters when they are not input channels.) The output shape is 64x4x14x14, and this will therefore become the input shape to the next layer. The next layer, according to summary, has 296 parameters. Let's ignore the batch axis to keep things simple. So for each of 14*14=196 locations we are multiplying 296-8=288 weights (ignoring the bias for simplicity), so that's 196*288=56_448 multiplications at this layer. The next layer will have 7*7*(1168-16)=56_448 multiplications.
What happened here is that our stride-2 convolution halved the grid size from 14x14 to 7x7, and we doubled the number of filters from 8 to 16, resulting in no overall change in the amount of computation. If we left the number of channels the same in each stride-2 layer, the amount of computation being done in the net would get less and less as it gets deeper. But we know that the deeper layers have to compute semantically rich features (such as eyes or fur), so we wouldn't expect that doing less computation would make sense.
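Here's that arithmetic written out (a sketch using the same per-grid-location accounting as above, ignoring biases):
layer2 = 14*14 * (8*4*3*3)     # 196 grid locations x 288 weights in conv(4, 8)
layer3 =  7*7  * (16*8*3*3)    #  49 grid locations x 1,152 weights in conv(8, 16)
layer2, layer3
# -> (56448, 56448)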
Another way to think of this is based on receptive fields.
The receptive field is the area of an image that is involved in the calculation of a layer. On the book's website, you'll find an Excel spreadsheet called conv-example.xlsx that shows the calculation of two stride-2 convolutional layers using an MNIST digit. Each layer has a single kernel.
Here, the cell with the green border is the cell we clicked on, and the blue highlighted cells are its precedents—that is, the cells used to calculate its value. These cells are the corresponding 3×3 area of cells from the input layer (on the left), and the cells from the filter (on the right). Let's now click trace precedents again, to see what cells are used to calculate these inputs.
In this example, we have just two convolutional layers, each of stride 2, so this is now tracing right back to the input image. We can see that a 7×7 area of cells in the input layer is used to calculate the single green cell in the Conv2 layer. This 7×7 area is the receptive field in the input of the green activation in Conv2. We can also see that a second filter kernel is needed now, since we have two layers.
As you see from this example, the deeper we are in the network (specifically, the more stride-2 convs we have before a layer), the larger the receptive field for an activation in that layer. A large receptive field means that a large amount of the input image is used to calculate each activation in that layer. We now know that in the deeper layers of the network we have semantically rich features, corresponding to larger receptive fields. Therefore, we'd expect that we'd need more weights for each of our features to handle this increasing complexity. This is another way of saying the same thing we mentioned in the previous section: when we introduce a stride-2 conv in our network, we should also increase the number of channels.
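If you'd like to check the receptive-field arithmetic without the spreadsheet, here's a small sketch (the function receptive_field is just illustrative) of the standard recurrence: each layer grows the field by ks-1 times the product of the strides of the layers before it.
def receptive_field(strides, ks=3):
    "Receptive field of one activation after a stack of ks x ks convolutions"
    rf, jump = 1, 1
    for s in strides:
        rf  += (ks-1) * jump   # each layer adds (ks-1) pixels at the current spacing
        jump *= s              # stride widens the spacing between neighboring cells
    return rf

receptive_field([2,2])   # two stride-2 convs -> the 7x7 area we saw in the spreadsheet
# -> 7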
When writing this particular chapter, we had a lot of questions we needed answers for, to be able to explain CNNs to you as best we could. Believe it or not, we found most of the answers on Twitter. We're going to take a quick break to talk to you about that now, before we move on to color images.
We are not, to say the least, big users of social networks in general. But our goal in writing this book is to help you become the best deep learning practitioner you can, and we would be remiss not to mention how important Twitter has been in our own deep learning journeys.
You see, there's another part of Twitter, far away from Donald Trump and the Kardashians, which is the part of Twitter where deep learning researchers and practitioners talk shop every day. As we were writing this section, Jeremy wanted to double-check that what we were saying about stride-2 convolutions was accurate, so he asked on Twitter:
A few minutes later, this answer popped up:
Christian Szegedy is the first author of Inception, the 2014 ImageNet winner and source of many key insights used in modern neural networks. Two hours later, this appeared:
Do you recognize that name? You saw it in <<chapter_production>>, when we were talking about the Turing Award winners who established the foundations of deep learning today!
Jeremy also asked on Twitter for help checking that our description of label smoothing in <<chapter_sizing_and_tta>> was accurate, and got a response, again directly from Christian Szegedy (label smoothing was originally introduced in the Inception paper):
Many of the top people in deep learning today are Twitter regulars, and are very open about interacting with the wider community. One good way to get started is to look at a list of Jeremy's recent Twitter likes, or Sylvain's. That way, you can see a list of Twitter users that we think have interesting and useful things to say.
Twitter is the main way we both stay up to date with interesting papers, software releases, and other deep learning news. For making connections with the deep learning community, we recommend getting involved both in the fast.ai forums and on Twitter.
That said, let's get back to the meat of this chapter. Up until now, we have only shown you examples of pictures in black and white, with one value per pixel. In practice, most colored images have three values per pixel to define their color. We'll look at working with color images next.
A color picture is a rank-3 tensor:
im = image2tensor(Image.open(image_bear()))
im.shape
torch.Size([3, 1000, 846])
show_image(im);
The first axis contains the channels, red, green, and blue:
_,axs = subplots(1,3)
for bear,ax,color in zip(im,axs,('Reds','Greens','Blues')):
show_image(255-bear, ax=ax, cmap=color)