from fastai.basic_train import *
from fastai.gen_doc.nbdoc import *
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
The fastai library structures training around a Learner object that binds together a PyTorch model, the data, an optimizer and a loss function; together these allow us to launch training.
basic_train contains the definition of this Learner class, along with the wrapper around the PyTorch optimizer that the library uses. It defines the basic training loop that is used each time you call the fit function in fastai (or one of its variants). This training loop is kept to the minimum number of instructions; most of its customization happens in Callback objects.
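As a rough sketch (simplified for illustration, not the actual fastai source), the loop run by fit looks something like this, with the callback handler notified at each stage:
def fit_sketch(epochs, model, loss_func, opt, data, cb_handler):
    # Simplified outline of the loop in basic_train; the real version also
    # runs validation and lets callbacks alter inputs, loss and control flow.
    cb_handler.on_train_begin()
    for epoch in range(epochs):
        cb_handler.on_epoch_begin()
        for xb, yb in data.train_dl:
            cb_handler.on_batch_begin(xb, yb)
            out = model(xb)
            cb_handler.on_loss_begin(out)
            loss = loss_func(out, yb)
            cb_handler.on_backward_begin(loss)
            loss.backward()
            cb_handler.on_backward_end()
            opt.step()
            cb_handler.on_step_end()
            opt.zero_grad()
            cb_handler.on_batch_end(loss)
        cb_handler.on_epoch_end()
    cb_handler.on_train_end()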
callback contains the definition of those, as well as the CallbackHandler that is responsible for the communication between the training loop and the Callbacks. It maintains a state dictionary so it can provide each Callback with all the information about the training loop, which makes it easy to implement any tweak you can think of.
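For instance, here is a minimal sketch of a custom Callback that reads the last_loss and iteration entries of that state dictionary, which the handler passes to every event. It only overrides the event it cares about:
class PrintLoss(Callback):
    # Toy callback: print the loss every 100 iterations.
    # Callback is in scope thanks to the wildcard imports at the top.
    def on_batch_end(self, last_loss, iteration, **kwargs):
        if iteration % 100 == 0:
            print(f'iteration {iteration}: loss {float(last_loss):.4f}')
Once a Learner exists (as created below), you would pass an instance with learn.fit(1, callbacks=[PrintLoss()]).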
In callbacks, each Callback is implemented in a separate module. Some deal with scheduling the hyperparameters, like callbacks.one_cycle, callbacks.lr_finder or callbacks.general_sched. Others allow special kinds of training, like callbacks.fp16 (mixed precision) or callbacks.rnn. The Recorder or callbacks.hooks are useful to save some internal data.
train then uses those callbacks to implement useful helper functions. Lastly, metrics contains all the functions you might want to call to evaluate your results.
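As a quick taste of those helpers (an illustrative sketch, to be run once a learn object like the one created below exists):
learn.lr_find()           # runs the LR-finder callback from callbacks.lr_finder
learn.recorder.plot()     # plot the loss against the learning rates it tried
learn = learn.to_fp16()   # wrap training in the mixed-precision callback (needs a suitable GPU)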
We'll do a quick overview of the key pieces of fastai's training modules; see the separate module docs for details on each. We'll use the classic MNIST dataset for the training documentation, cut down to just 3's and 7's. To minimize boilerplate we grab the data from URLs.MNIST_SAMPLE with untar_data, which automatically downloads and unzips it if that hasn't been done already, then put it in an ImageDataBunch.
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
We can create a minimal, simple CNN using simple_cnn (see models for details on creating models):
model = simple_cnn((3,16,16,2))
The most important object for training models is Learner, which needs to know, at minimum, what data to train with and what model to train.
learn = Learner(data, model)
That's enough to train a model, which is done using fit. If you have a CUDA-capable GPU, it will be used automatically. You have to say how many epochs to train for.
learn.fit(1)
Total time: 00:02
epoch  train_loss  valid_loss
1      0.141339    0.121598    (00:02)
To see how our training is going, we can request that it reports various metrics after each epoch. You can pass them to the constructor, or set them later. Note that metrics are always calculated on the validation set.
learn.metrics=[accuracy]
learn.fit(1)
Total time: 00:02
epoch  train_loss  valid_loss  accuracy
1      0.109016    0.091778    0.969578    (00:02)
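Equivalently, as a sketch of the constructor route mentioned above, the metrics can be given when the Learner is built:
learn = Learner(data, model, metrics=[accuracy])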
You can use callbacks to modify training in almost any way you can imagine. For instance, we've provided a callback to implement Leslie Smith's 1cycle training method.
cb = OneCycleScheduler(learn, lr_max=0.01)
learn.fit(1, callbacks=cb)
Total time: 00:02
epoch  train_loss  valid_loss  accuracy
1      0.091946    0.068201    0.974975    (00:02)
The Recorder callback is automatically added for you, and you can use it to see what happened in your training, e.g.:
learn.recorder.plot_lr(show_moms=True)
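The Recorder also stores the losses it saw, so, for example, the following plots them from the same recorder:
learn.recorder.plot_losses()   # training loss per batch and validation loss per epoch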
Many of the callbacks can be used more easily by taking advantage of the Learner extensions in train. For instance, instead of creating OneCycleScheduler manually as above, you can simply call Learner.fit_one_cycle:
learn.fit_one_cycle(1)
Total time: 00:02
epoch  train_loss  valid_loss  accuracy
0      0.044362    0.045060    0.984298    (00:02)
Note that if you're training a model for one of our supported applications, there's a lot of help available to you in the application modules:
For instance, let's use create_cnn (from vision) to quickly fine-tune a pre-trained ImageNet model for MNIST (not a very practical approach, of course, since MNIST is handwriting and our model is pre-trained on photos!).
learn = create_cnn(data, models.resnet18, metrics=accuracy)
learn.fit_one_cycle(1)
Total time: 00:09
epoch  train_loss  valid_loss  accuracy
0      0.093473    0.068315    0.976938    (00:09)
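create_cnn freezes the pre-trained backbone and trains only the new head at first; a common follow-up (shown here only as a sketch, not run above) is to unfreeze and fine-tune the whole network:
learn.unfreeze()                                  # make all layers trainable
learn.fit_one_cycle(1, max_lr=slice(1e-5, 1e-3))  # smaller learning rates for the earlier layers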