#hide
! [ -e /content ] && pip install -Uqq fastbook
import fastbook
fastbook.setup_book()
#hide
from fastbook import *
Now that you understand what deep learning is, what it's for, and how to create and deploy a model, it's time for us to go deeper! In an ideal world deep learning practitioners wouldn't have to know every detail of how things work under the hood… But as yet, we don't live in an ideal world. The truth is, to make your model really work, and work reliably, there are a lot of details you have to get right, and a lot of details that you have to check. This process requires being able to look inside your neural network as it trains, and as it makes predictions, find possible problems, and know how to fix them.
So, from here on in the book we are going to do a deep dive into the mechanics of deep learning. What is the architecture of a computer vision model, an NLP model, a tabular model, and so on? How do you create an architecture that matches the needs of your particular domain? How do you get the best possible results from the training process? How do you make things faster? What do you have to change as your datasets change?
We will start by repeating the same basic applications that we looked at in the first chapter, but we are going to do two things:

- Make them better.
- Apply them to a wider variety of types of data.
In order to do these two things, we will have to learn all of the pieces of the deep learning puzzle. This includes different types of layers, regularization methods, optimizers, how to put layers together into architectures, labeling techniques, and much more. We are not just going to dump all of these things on you, though; we will introduce them progressively as needed, to solve actual problems related to the projects we are working on.
In our very first model we learned how to classify dogs versus cats. Just a few years ago this was considered a very challenging task—but today, it's far too easy! We will not be able to show you the nuances of training models with this problem, because we get a nearly perfect result without worrying about any of the details. But it turns out that the same dataset also allows us to work on a much more challenging problem: figuring out what breed of pet is shown in each image.
In <<chapter_intro>> we presented the applications as already-solved problems. But this is not how things work in real life. We start with some dataset that we know nothing about. We then have to figure out how it is put together, how to extract the data we need from it, and what that data looks like. For the rest of this book we will be showing you how to solve these problems in practice, including all of the intermediate steps necessary to understand the data that you are working with and test your modeling as you go.
We already downloaded the Pet dataset, and we can get a path to this dataset using the same code as in <<chapter_intro>>:
from fastai.vision.all import *
path = untar_data(URLs.PETS)
Now, if we are going to understand how to extract the breed of each pet from each image, we're going to need to understand how this data is laid out. Such details of data layout are a vital piece of the deep learning puzzle. Data is usually provided in one of these two ways:

- Individual files representing items of data, such as text documents or images, possibly organized into folders or with filenames representing information about those items
- A table of data, such as in CSV format, where each row is an item, and which may include filenames providing a connection between the data in the table and data in other formats, such as text documents and images
There are exceptions to these rules—particularly in domains such as genomics, where there can be binary database formats or even network streams—but overall the vast majority of the datasets you'll work with will use some combination of these two formats.
To see what is in our dataset we can use the `ls` method:
#hide
Path.BASE_PATH = path
path.ls()
(#3) [Path('annotations'),Path('images'),Path('models')]
We can see that this dataset provides us with images and annotations directories. The website for the dataset tells us that the annotations directory contains information about where the pets are rather than what they are. In this chapter, we will be doing classification, not localization, which is to say that we care about what the pets are, not where they are. Therefore, we will ignore the annotations directory for now. So, let's have a look inside the images directory:
(path/"images").ls()
(#7394) [Path('images/great_pyrenees_173.jpg'),Path('images/wheaten_terrier_46.jpg'),Path('images/Ragdoll_262.jpg'),Path('images/german_shorthaired_3.jpg'),Path('images/american_bulldog_196.jpg'),Path('images/boxer_188.jpg'),Path('images/staffordshire_bull_terrier_173.jpg'),Path('images/basset_hound_71.jpg'),Path('images/staffordshire_bull_terrier_37.jpg'),Path('images/yorkshire_terrier_18.jpg')...]
Most functions and methods in fastai that return a collection use a class called `L`. `L` can be thought of as an enhanced version of the ordinary Python `list` type, with added conveniences for common operations. For instance, when we display an object of this class in a notebook it appears in the format shown there. The first thing that is shown is the number of items in the collection, prefixed with a `#`. You'll also see in the preceding output that the list is suffixed with an ellipsis. This means that only the first few items are displayed—which is a good thing, because we would not want more than 7,000 filenames on our screen!
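If you want to experiment with `L` yourself, here is a minimal sketch of a few of its conveniences (this assumes the fastai import at the top of the chapter, which re-exports fastcore's `L`; the display format noted in the comment is approximate):
t = L(range(12))
print(t)                     # shows something like: (#12) [0,1,2,3,4,5,6,7,8,9...]
print(t[3, 7, 10])           # index with several indices at once
print(t.map(lambda o: o*2))  # apply a function to every element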
By examining these filenames, we can see how they appear to be structured. Each filename contains the pet breed, then an underscore (`_`), a number, and finally the file extension. We need to create a piece of code that extracts the breed from a single `Path`. Jupyter notebooks make this easy, because we can gradually build up something that works, and then use it for the entire dataset. We do have to be careful to not make too many assumptions at this point. For instance, if you look carefully you may notice that some of the pet breeds contain multiple words, so we cannot simply break at the first `_` character that we find. To allow us to test our code, let's pick out one of these filenames:
fname = (path/"images").ls()[0]
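Before reaching for a regular expression, it's worth seeing why a simpler approach falls short. This quick check (plain Python, no extra assumptions) shows that splitting at the first underscore would lose part of a multi-word breed name:
# Naive approach: split at underscores and take the first piece.
# For a multi-word breed such as 'great_pyrenees_173.jpg' this gives just 'great'.
fname.name.split('_')[0]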
The most powerful and flexible way to extract information from strings like this is to use a regular expression, also known as a regex. A regular expression is a special string, written in the regular expression language, which specifies a general rule for deciding if another string passes a test (i.e., "matches" the regular expression), and also possibly for plucking a particular part or parts out of that other string.
In this case, we need a regular expression that extracts the pet breed from the filename.
We do not have the space to give you a complete regular expression tutorial here, but there are many excellent ones online and we know that many of you will already be familiar with this wonderful tool. If you're not, that is totally fine—this is a great opportunity for you to rectify that! We find that regular expressions are one of the most useful tools in our programming toolkit, and many of our students tell us that this is one of the things they are most excited to learn about. So head over to Google and search for "regular expressions tutorial" now, and then come back here after you've had a good look around. The book's website also provides a list of our favorites.
a: Not only are regular expressions dead handy, but they also have interesting roots. They are "regular" because they were originally examples of a "regular" language, the lowest rung within the Chomsky hierarchy, a grammar classification developed by linguist Noam Chomsky, who also wrote Syntactic Structures, the pioneering work searching for the formal grammar underlying human language. This is one of the charms of computing: it may be that the hammer you reach for every day in fact came from a spaceship.
When you are writing a regular expression, the best way to start is just to try it against one example at first. Let's use the `findall` method to try a regular expression against the filename of the `fname` object:
re.findall(r'(.+)_\d+.jpg$', fname.name)
['great_pyrenees']
This regular expression plucks out all the characters leading up to the last underscore character, as long as the subsequent characters are numerical digits and then the JPEG file extension.
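Since breed names can themselves contain underscores, it's worth a quick sanity check (using two of the filenames we saw in the directory listing above) that the pattern keeps the whole breed name and only strips the trailing number and extension:
# The greedy (.+) backtracks just enough for _\d+.jpg$ to match at the end,
# so the capture group keeps everything before the final underscore-number part.
for name in ['great_pyrenees_173.jpg', 'staffordshire_bull_terrier_37.jpg']:
    print(re.findall(r'(.+)_\d+.jpg$', name))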
Now that we've confirmed the regular expression works for the example, let's use it to label the whole dataset. fastai comes with many classes to help with labeling. For labeling with regular expressions, we can use the `RegexLabeller` class. In this example we use the data block API we saw in <<chapter_production>> (in fact, we nearly always use the data block API—it's so much more flexible than the simple factory methods we saw in <<chapter_intro>>):
pets = DataBlock(blocks = (ImageBlock, CategoryBlock),
get_items=get_image_files,
splitter=RandomSplitter(seed=42),
get_y=using_attr(RegexLabeller(r'(.+)_\d+.jpg$'), 'name'),
item_tfms=Resize(460),
batch_tfms=aug_transforms(size=224, min_scale=0.75))
dls = pets.dataloaders(path/"images")
One important piece of this `DataBlock` call that we haven't seen before is in these two lines:
item_tfms=Resize(460),
batch_tfms=aug_transforms(size=224, min_scale=0.75)
These lines implement a fastai data augmentation strategy which we call presizing. Presizing is a particular way to do image augmentation that is designed to minimize data destruction while maintaining good performance.
We need our images to have the same dimensions, so that they can collate into tensors to be passed to the GPU. We also want to minimize the number of distinct augmentation computations we perform. The performance requirement suggests that we should, where possible, compose our augmentation transforms into fewer transforms (to reduce the number of computations and the number of lossy operations) and transform the images into uniform sizes (for more efficient processing on the GPU).
The challenge is that, if performed after resizing down to the augmented size, various common data augmentation transforms might introduce spurious empty zones, degrade data, or both. For instance, rotating an image by 45 degrees fills corner regions of the new bounds with emptiness, which will not teach the model anything. Many rotation and zooming operations will require interpolating to create pixels. These interpolated pixels are derived from the original image data but are still of lower quality.
To work around these challenges, presizing adopts two strategies that are shown in <<presizing>>:

1. Resize images to relatively "large" dimensions, that is, dimensions significantly larger than the target training dimensions.
2. Compose all of the common augmentation operations (including a resize to the final target size) into one, and perform the combined operation on the GPU only once at the end of processing, rather than performing the operations individually and interpolating multiple times.
The first step, the resize, creates images large enough that they have spare margin to allow further augmentation transforms on their inner regions without creating empty zones. This transformation works by resizing to a square, using a large crop size. On the training set, the crop area is chosen randomly, and the size of the crop is selected to cover the entire width or height of the image, whichever is smaller.
In the second step, the GPU is used for all data augmentation, and all of the potentially destructive operations are done together, with a single interpolation at the end.
This picture shows the two steps:

1. Crop full width or height: This is in `item_tfms`, so it's applied to each individual image before it is copied to the GPU. It's used to ensure all images are the same size. On the training set, the crop area is chosen randomly. On the validation set, the center square of the image is always chosen.
2. Random crop and augment: This is in `batch_tfms`, so it's applied to a batch all at once on the GPU, which means it's fast. On the validation set, only the resize to the final size needed for the model is done here. On the training set, the random crop and any other augmentations are done first.

To implement this process in fastai you use `Resize` as an item transform with a large size, and `RandomResizedCrop` as a batch transform with a smaller size. `RandomResizedCrop` will be added for you if you include the `min_scale` parameter in your `aug_transforms` function, as was done in the `DataBlock` call in the previous section. Alternatively, you can use `pad` or `squish` instead of `crop` (the default) for the initial `Resize`, as sketched below.
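For instance, here is a sketch of the same `DataBlock` as before, but using squish for the initial resize instead of the default crop (`ResizeMethod.Squish` is the fastai constant for that behavior; `ResizeMethod.Pad` would pad instead):
# Same pipeline as above, but the item-level Resize squishes each image to a
# square rather than cropping it.
pets_squish = DataBlock(blocks = (ImageBlock, CategoryBlock),
                 get_items=get_image_files,
                 splitter=RandomSplitter(seed=42),
                 get_y=using_attr(RegexLabeller(r'(.+)_\d+.jpg$'), 'name'),
                 item_tfms=Resize(460, method=ResizeMethod.Squish),
                 batch_tfms=aug_transforms(size=224, min_scale=0.75))
dls_squish = pets_squish.dataloaders(path/"images")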
<<interpolations>> shows the difference between an image that has been zoomed, interpolated, rotated, and then interpolated again (the approach used by all other deep learning libraries), shown here on the right, and an image that has been zoomed and rotated as one operation and then interpolated just once (the fastai approach), shown here on the left.
#hide_input
#id interpolations
#caption A comparison of fastai's data augmentation strategy (left) and the traditional approach (right).
dblock1 = DataBlock(blocks=(ImageBlock(), CategoryBlock()),
get_y=parent_label,
item_tfms=Resize(460))
# Place an image in the 'images/grizzly.jpg' subfolder where this notebook is located before running this
dls1 = dblock1.dataloaders([(Path.cwd()/'images'/'grizzly.jpg')]*100, bs=8)
dls1.train.get_idxs = lambda: Inf.ones
x,y = dls1.valid.one_batch()
_,axs = subplots(1, 2)
x1 = TensorImage(x.clone())
x1 = x1.affine_coord(sz=224)
x1 = x1.rotate(draw=30, p=1.)
x1 = x1.zoom(draw=1.2, p=1.)
x1 = x1.warp(draw_x=-0.2, draw_y=0.2, p=1.)
tfms = setup_aug_tfms([Rotate(draw=30, p=1, size=224), Zoom(draw=1.2, p=1., size=224),
Warp(draw_x=-0.2, draw_y=0.2, p=1., size=224)])
x = Pipeline(tfms)(x)
#x.affine_coord(coord_tfm=coord_tfm, sz=size, mode=mode, pad_mode=pad_mode)
TensorImage(x[0]).show(ctx=axs[0])
TensorImage(x1[0]).show(ctx=axs[1]);
You can see that the image on the right is less well defined and has reflection padding artifacts in the bottom-left corner; also, the grass at the top left has disappeared entirely. We find that in practice using presizing significantly improves the accuracy of models, and often results in speedups too.
The fastai library also provides simple ways to check your data looks right before training a model, which is an extremely important step. We'll look at those next.
We can never just assume that our code is working perfectly. Writing a `DataBlock` is just like writing a blueprint. You will get an error message if you have a syntax error somewhere in your code, but you have no guarantee that your template is going to work on your data source as you intend. So, before training a model you should always check your data. You can do this using the `show_batch` method:
dls.show_batch(nrows=1, ncols=3)
Take a look at each image, and check that each one seems to have the correct label for that breed of pet. Often, data scientists work with data they are not as familiar with as domain experts are: for instance, I actually don't know what a lot of these pet breeds are. Since I am not an expert on pet breeds, I would use Google Images at this point to search for a few of these breeds, and make sure the images look similar to what I see in this output.
If you made a mistake while building your `DataBlock`, it is very likely you won't see it before this step. To debug this, we encourage you to use the `summary` method. It will attempt to create a batch from the source you give it, with a lot of details. Also, if it fails, you will see exactly at which point the error happens, and the library will try to give you some help. For instance, one common mistake is to forget to use a `Resize` transform, so you end up with pictures of different sizes and are not able to batch them. Here is what the summary would look like in that case (note that the exact text may have changed since the time of writing, but it will give you an idea):
#hide_output
pets1 = DataBlock(blocks = (ImageBlock, CategoryBlock),
get_items=get_image_files,
splitter=RandomSplitter(seed=42),
get_y=using_attr(RegexLabeller(r'(.+)_\d+.jpg$'), 'name'))
pets1.summary(path/"images")
Setting-up type transforms pipelines
Collecting items from /home/sgugger/.fastai/data/oxford-iiit-pet/images
Found 7390 items
2 datasets of sizes 5912,1478
Setting up Pipeline: PILBase.create
Setting up Pipeline: partial -> Categorize
Building one sample
Pipeline: PILBase.create
starting from
/home/sgugger/.fastai/data/oxford-iiit-pet/images/american_bulldog_83.jpg
applying PILBase.create gives
PILImage mode=RGB size=375x500
Pipeline: partial -> Categorize
starting from
/home/sgugger/.fastai/data/oxford-iiit-pet/images/american_bulldog_83.jpg
applying partial gives
american_bulldog
applying Categorize gives
TensorCategory(12)
Final sample: (PILImage mode=RGB size=375x500, TensorCategory(12))
Setting up after_item: Pipeline: ToTensor
Setting up before_batch: Pipeline:
Setting up after_batch: Pipeline: IntToFloatTensor
Building one batch
Applying item_tfms to the first sample:
Pipeline: ToTensor
starting from
(PILImage mode=RGB size=375x500, TensorCategory(12))
applying ToTensor gives
(TensorImage of size 3x500x375, TensorCategory(12))
Adding the next 3 samples
No before_batch transform to apply
Collating items in a batch
Error! It's not possible to collate your items in a batch
Could not collate the 0-th members of your tuples because got the following
shapes:
torch.Size([3, 500, 375]),torch.Size([3, 375, 500]),torch.Size([3, 333, 500]),
torch.Size([3, 375, 500])
You can see exactly how we gathered the data and split it, how we went from a filename to a sample (the tuple (image, category)), then what item transforms were applied and how it failed to collate those samples in a batch (because of the different shapes).
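A minimal sketch of the fix: add an item-level `Resize` so that every image reaches the collation step with the same shape, and check the summary again (this uses `DataBlock.new`, which creates a copy of the block with different transforms):
# Add the missing Resize as an item transform; with uniform image sizes the
# batch should now collate successfully.
pets1 = pets1.new(item_tfms=Resize(460))
pets1.summary(path/"images")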
Once you think your data looks right, we generally recommend the next step should be using it to train a simple model. We often see people put off the training of an actual model for far too long. As a result, they don't actually find out what their baseline results look like. Perhaps your problem doesn't need lots of fancy domain-specific engineering. Or perhaps the data doesn't seem to train the model at all. These are things that you want to know as soon as possible. For this initial test, we'll use the same simple model that we used in <<chapter_intro>>:
learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(2)
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 1.551305 | 0.322132 | 0.106225 | 00:19 |
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 0.529473 | 0.312148 | 0.095399 | 00:23 |
1 | 0.330207 | 0.245883 | 0.080514 | 00:24 |
As we've briefly discussed before, the table shown when we fit a model shows us the results after each epoch of training. Remember, an epoch is one complete pass through all of the images in the data. The columns shown are the average loss over the items of the training set, the loss on the validation set, and any metrics that we requested—in this case, the error rate.
Remember that loss is whatever function we've decided to use to optimize the parameters of our model. But we haven't actually told fastai what loss function we want to use. So what is it doing? fastai will generally try to select an appropriate loss function based on what kind of data and model you are using. In this case we have image data and a categorical outcome, so fastai will default to using cross-entropy loss.
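You can confirm this by inspecting the `Learner` directly; its `loss_func` attribute holds whatever loss function fastai selected (for these `DataLoaders` it should report a flattened form of PyTorch's cross-entropy loss):
# Show the loss function that fastai picked automatically for this Learner.
learn.loss_func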
Cross-entropy loss is a loss function that is similar to the one we used in the previous chapter, but (as we'll see) has two benefits:

- It works even when our dependent variable has more than two categories.
- It results in faster and more reliable training.
In order to understand how cross-entropy loss works for dependent variables with more than two categories, we first have to understand what the actual data and activations that are seen by the loss function look like.
Let's take a look at the activations of our model. To actually get a batch of real data from our `DataLoaders`, we can use the `one_batch` method:
x,y = dls.one_batch()
As you see, this returns the dependent and independent variables, as a mini-batch. Let's see what is actually contained in our dependent variable:
y
TensorCategory([ 0, 5, 23, 36, 5, 20, 29, 34, 33, 32, 31, 24, 12, 36, 8, 26, 30, 2, 12, 17, 7, 23, 12, 29, 21, 4, 35, 33, 0, 20, 26, 30, 3, 6, 36, 2, 17, 32, 11, 6, 3, 30, 5, 26, 26, 29, 7, 36, 31, 26, 26, 8, 13, 30, 11, 12, 36, 31, 34, 20, 15, 8, 8, 23], device='cuda:5')
Our batch size is 64, so we have 64 rows in this tensor. Each row is a single integer between 0 and 36, representing our 37 possible pet breeds. We can view the predictions (that is, the activations of the final layer of our neural network) using `Learner.get_preds`. This function either takes a dataset index (0 for train and 1 for valid) or an iterator of batches. Thus, we can pass it a simple list with our batch to get our predictions. It returns predictions and targets by default, but since we already have the targets, we can effectively ignore them by assigning to the special variable `_`:
preds,_ = learn.get_preds(dl=[(x,y)])
preds[0]
tensor([9.9911e-01, 5.0433e-05, 3.7515e-07, 8.8590e-07, 8.1794e-05, 1.8991e-05, 9.9280e-06, 5.4656e-07, 6.7920e-06, 2.3486e-04, 3.7872e-04, 2.0796e-05, 4.0443e-07, 1.6933e-07, 2.0502e-07, 3.1354e-08, 9.4115e-08, 2.9782e-06, 2.0243e-07, 8.5262e-08, 1.0900e-07, 1.0175e-07, 4.4780e-09, 1.4285e-07, 1.0718e-07, 8.1411e-07, 3.6618e-07, 4.0950e-07, 3.8525e-08, 2.3660e-07, 5.3747e-08, 2.5448e-07, 6.5860e-08, 8.0937e-05, 2.7464e-07, 5.6760e-07, 1.5462e-08])
The actual predictions are 37 probabilities between 0 and 1, which add up to 1 in total:
len(preds[0]),preds[0].sum()
(37, tensor(1.0000))
To transform the activations of our model into predictions like this, we used something called the softmax activation function.
In our classification model, we use the softmax activation function in the final layer to ensure that the activations are all between 0 and 1, and that they sum to 1.
Softmax is similar to the sigmoid function, which we saw earlier. As a reminder, sigmoid looks like this:
plot_function(torch.sigmoid, min=-4,max=4)
We can apply this function to a single column of activations from a neural network, and get back a column of numbers between 0 and 1, so it's a very useful activation function for our final layer.
Now think about what happens if we want to have more categories in our target (such as our 37 pet breeds). That means we'll need more activations than just a single column: we need an activation per category. We can create, for instance, a neural net that predicts 3s and 7s that returns two activations, one for each class—this will be a good first step toward creating the more general approach. Let's just use some random numbers with a standard deviation of 2 (so we multiply `randn` by 2) for this example, assuming we have 6 images and 2 possible categories (where the first column represents 3s and the second is 7s):
#hide
torch.random.manual_seed(42);
acts = torch.randn((6,2))*2
acts
tensor([[ 0.6734, 0.2576], [ 0.4689, 0.4607], [-2.2457, -0.3727], [ 4.4164, -1.2760], [ 0.9233, 0.5347], [ 1.0698, 1.6187]])
We can't just take the sigmoid of this directly, since we don't get rows that add to 1 (i.e., we want the probability of being a 3 plus the probability of being a 7 to add up to 1):
acts.sigmoid()
tensor([[0.6623, 0.5641], [0.6151, 0.6132], [0.0957, 0.4079], [0.9881, 0.2182], [0.7157, 0.6306], [0.7446, 0.8346]])
In <<chapter_mnist_basics>>, our neural net created a single activation per image, which we passed through the `sigmoid` function. That single activation represented the model's confidence that the input was a 3. Binary problems are a special case of classification problems, because the target can be treated as a single boolean value, as we did in `mnist_loss`. But binary problems can also be thought of in the context of the more general group of classifiers with any number of categories: in this case, we happen to have two categories. As we saw in the bear classifier, our neural net will return one activation per category.
So in the binary case, what do those activations really indicate? A single pair of activations simply indicates the relative confidence of the input being a 3 versus being a 7. The overall values, whether they are both high, or both low, don't matter—all that matters is which is higher, and by how much.
We would expect that, since this is just another way of representing the same problem, we would be able to use `sigmoid` directly on the two-activation version of our neural net. And indeed we can! We can just take the difference between the neural net activations, because that reflects how much more sure we are of the input being a 3 than a 7, and then take the sigmoid of that:
(acts[:,0]-acts[:,1]).sigmoid()
tensor([0.6025, 0.5021, 0.1332, 0.9966, 0.5959, 0.3661])
The second column (the probability of it being a 7) will then just be that value subtracted from 1. Now, we need a way to do all this that also works for more than two columns. It turns out that this function, called `softmax`, is exactly that:
def softmax(x): return exp(x) / exp(x).sum(dim=1, keepdim=True)
jargon: Exponential function (exp): Literally defined as `e**x`, where `e` is a special number approximately equal to 2.718. It is the inverse of the natural logarithm function. Note that `exp` is always positive, and it increases very rapidly!
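As a quick illustration (the values in the comments are approximate):
# exp is always positive and grows very quickly:
# exp(0) = 1.0, exp(1) ≈ 2.718, exp(5) ≈ 148.4
torch.exp(torch.tensor([0., 1., 5.]))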
Let's check that `softmax` returns the same values as `sigmoid` for the first column, and those values subtracted from 1 for the second column:
sm_acts = torch.softmax(acts, dim=1)
sm_acts
tensor([[0.6025, 0.3975], [0.5021, 0.4979], [0.1332, 0.8668], [0.9966, 0.0034], [0.5959, 0.4041], [0.3661, 0.6339]])
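We can make this check explicit in code (a small sanity check; `torch.allclose` compares tensors up to floating-point tolerance):
# The first softmax column should equal the sigmoid of the difference between
# the two activation columns, and every row of softmax outputs should sum to 1.
print(torch.allclose(sm_acts[:,0], (acts[:,0]-acts[:,1]).sigmoid()))
print(torch.allclose(sm_acts.sum(dim=1), torch.ones(6)))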
`softmax` is the multi-category equivalent of `sigmoid`—we have to use it any time we have more than two categories and the probabilities of the categories must add to 1, and we often use it even when there are just two categories, just to make things a bit more consistent. We could create other functions that have the properties that all activations are between 0 and 1, and sum to 1; however, no other function has the same relationship to the sigmoid function, which we've seen is smooth and symmetric. Also, we'll see shortly that the softmax function works well hand-in-hand with the loss function we will look at in the next section.
If we have three output activations, such as in our bear classifier, calculating softmax for a single bear image would then look something like <<bear_softmax>>.
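As a small numerical illustration (the three activation values below are hypothetical, chosen only to mirror the idea of one raw activation per bear class):
# Hypothetical raw activations for a single image, one per class (e.g. grizzly,
# black, teddy); softmax turns them into probabilities that sum to 1
# (roughly 0.22, 0.02, and 0.76 here).
bear_acts = tensor([0.02, -2.49, 1.25])
torch.softmax(bear_acts, dim=0)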