from fastai.basic_train import *
from fastai.gen_doc.nbdoc import *
from fastai.vision import *
from fastai.callbacks import *
The fastai library structures its training process around the Learner class, whose instances bind together a PyTorch model, a dataset, an optimizer, and a loss function; the resulting learner object is what we use to launch training.
basic_train defines this Learner class, along with the wrapper around the PyTorch optimizer that the library uses. It also defines the basic training loop that runs each time you call the fit method (or one of its variants) in fastai. This training loop is deliberately bare-bones and takes only a few lines of code; you can customize it by supplying an optional Callback argument to the fit method.
callback defines the Callback class and the CallbackHandler class, which is responsible for the communication between the training loop and the methods of each Callback. The CallbackHandler maintains a state dictionary that gives each Callback object all the information about the training loop it belongs to, putting almost any tweak of the training loop within your reach.
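The relationship between the loop, the handler, and the callbacks can be sketched in plain Python. This is a simplified illustration, not fastai's actual implementation: the names `SketchCallback`, `SketchHandler`, `LossLogger`, and `fit_sketch` are hypothetical, and the training step is a stand-in that just returns a number.

```python
class SketchCallback:
    """Hypothetical base class: every hook is a no-op by default."""
    def on_train_begin(self, **state): pass
    def on_batch_end(self, **state): pass
    def on_epoch_end(self, **state): pass


class SketchHandler:
    """Hypothetical handler: owns the shared state and forwards events."""
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)
        self.state = {"epoch": 0, "iteration": 0, "last_loss": None}

    def __call__(self, event, **updates):
        # Refresh the shared state, then give every callback a full view of it.
        self.state.update(updates)
        for cb in self.callbacks:
            getattr(cb, event)(**self.state)


class LossLogger(SketchCallback):
    """Example callback: records the loss seen after every batch."""
    def __init__(self):
        self.losses = []

    def on_batch_end(self, last_loss, **state):
        self.losses.append(last_loss)


def fit_sketch(epochs, batches, step, callbacks=()):
    """Bare-bones loop: all customisation happens through the callbacks."""
    handler = SketchHandler(callbacks)
    handler("on_train_begin")
    iteration = 0
    for epoch in range(epochs):
        for batch in batches:
            loss = step(batch)  # stand-in for forward, backward, optimizer step
            iteration += 1
            handler("on_batch_end", epoch=epoch, iteration=iteration, last_loss=loss)
        handler("on_epoch_end", epoch=epoch)
    return handler.state


logger = LossLogger()
final_state = fit_sketch(2, [1.0, 2.0, 3.0], step=lambda b: b * 0.5, callbacks=[logger])
```

The loop itself never needs to change: a new behaviour is added by writing another callback that reads whatever it needs from the shared state.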
callbacks implements each predefined Callback class of the fastai library in separate modules. Some deal with scheduling the hyperparameters, like callbacks.one_cycle, callbacks.lr_finder and callbacks.general_sched. Others allow special kinds of training, like callbacks.fp16 (mixed precision) and callbacks.rnn. The Recorder and callbacks.hooks are useful for saving internal data generated in the training loop.
train then uses these callbacks to implement useful helper functions. Lastly, metrics contains all the functions and classes you might want to use to evaluate your training results; simpler metrics are implemented as functions, while more complicated ones are subclasses of Callback. For more details on implementing metrics as a Callback, please refer to creating your own metrics.
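To illustrate why some metrics fit in a function while others need state, here is a minimal sketch in plain Python. It does not use fastai's Callback base class; `accuracy_sketch` and `PrecisionSketch` are hypothetical names, and the hook signatures are simplified assumptions.

```python
def accuracy_sketch(preds, targets):
    """Fraction of predicted labels that match the targets."""
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)


class PrecisionSketch:
    """Hypothetical stateful metric: precision for one positive class.

    Precision cannot be averaged batch by batch, so the counts are
    accumulated over the epoch and combined only at the end.
    """
    def __init__(self, positive=1):
        self.positive = positive

    def on_epoch_begin(self):
        self.true_pos = 0
        self.pred_pos = 0

    def on_batch_end(self, preds, targets):
        for p, t in zip(preds, targets):
            if p == self.positive:
                self.pred_pos += 1
                self.true_pos += int(t == self.positive)

    def on_epoch_end(self):
        return self.true_pos / self.pred_pos if self.pred_pos else 0.0


# Two validation batches of (predicted labels, target labels).
batches = [([1, 0, 1], [1, 0, 0]), ([1, 1], [1, 1])]
metric = PrecisionSketch(positive=1)
metric.on_epoch_begin()
for preds, targets in batches:
    metric.on_batch_end(preds, targets)
precision = metric.on_epoch_end()
```

A per-batch function is enough when the batch values can simply be averaged; the class form is useful when the final value depends on totals collected across the whole epoch.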
We'll do a quick overview of the key pieces of fastai's training modules. See the separate module docs for details on each. For demonstration, we'll use a small subset of the classic MNIST dataset containing only images of 3's and 7's. To minimize the boilerplate in our docs, we've defined a function to grab the data from URLs.MNIST_SAMPLE. The function automatically downloads the dataset and unzips it if necessary; we then use the result to create an ImageDataBunch object:
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
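As a rough picture of what building a data object from a folder involves, the sketch below collects labelled files from a directory tree using only the standard library. It assumes a layout of `<split>/<label>/<file>` (for example `train/3/...` and `valid/7/...`); the helper `labelled_files` and the fake files are hypothetical and are not part of fastai.

```python
import tempfile
from pathlib import Path


def labelled_files(root, split):
    """Return sorted (relative path, label) pairs; the label is the parent folder name."""
    split_dir = Path(root) / split
    return sorted(
        (str(f.relative_to(root)), f.parent.name)
        for f in split_dir.glob("*/*")
        if f.is_file()
    )


# Build a tiny fake dataset with the assumed layout: <split>/<label>/<file>
with tempfile.TemporaryDirectory() as tmp:
    for split in ("train", "valid"):
        for label in ("3", "7"):
            folder = Path(tmp) / split / label
            folder.mkdir(parents=True)
            (folder / "img0.png").touch()
    train_items = labelled_files(tmp, "train")
    train_labels = [label for _, label in train_items]
```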
We can create minimal CNNs using simple_cnn (see models for details on creating models):
model = simple_cnn((3,16,16,2))
A Learner object plays a central role in training models; at the very minimum, it needs to know what data to train with and what model to train.
learn = Learner(data, model)
learn.fit(1)
epoch | train_loss | valid_loss |
---|---|---|
1 | 0.124981 | 0.097195 |
To see how our training is going, we can ask the Learner to report various metrics after each epoch. You can pass them to the constructor, or set them later. Note that metrics are always calculated on the validation set.
learn.metrics=[accuracy]
learn.fit(1)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 0.081563 | 0.062798 | 0.976938 |
You can use callbacks to modify training in almost any way you can imagine. For instance, we've provided a callback to implement Leslie Smith's 1cycle training method.
cb = OneCycleScheduler(learn, lr_max=0.01)
learn.fit(1, callbacks=cb)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 0.055955 | 0.045469 | 0.984298 |
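To give a feel for the shape of a 1cycle learning-rate schedule, here is a minimal standalone sketch. It is not the code behind OneCycleScheduler: the function names are hypothetical, and the warm-up fraction, the divisors, and the choice of cosine annealing for both phases are illustrative assumptions rather than confirmed library defaults.

```python
import math


def cosine_anneal(start, end, pct):
    """Move from start to end along half a cosine wave as pct goes from 0 to 1."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)


def one_cycle_lrs(lr_max, n_iter, pct_start=0.3, div_factor=25.0, final_div=1e4):
    """Per-iteration learning rates: warm up to lr_max, then anneal far below the start.

    pct_start, div_factor and final_div are assumed values for illustration.
    """
    warmup = int(n_iter * pct_start)
    lr_low = lr_max / div_factor    # starting learning rate
    lr_final = lr_max / final_div   # learning rate at the very end
    lrs = []
    for i in range(n_iter):
        if i < warmup:
            lrs.append(cosine_anneal(lr_low, lr_max, i / warmup))
        else:
            pct = (i - warmup) / max(1, n_iter - warmup - 1)
            lrs.append(cosine_anneal(lr_max, lr_final, pct))
    return lrs


schedule = one_cycle_lrs(lr_max=0.01, n_iter=100)
```

With these assumed settings the rate climbs for the first 30 iterations, peaks at `lr_max`, and then decays for the remaining 70.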
The Recorder callback is automatically added for you, and you can use it to see what happened in your training, e.g.:
learn.recorder.plot_lr(show_moms=True)
Many of the callbacks can be used more easily by taking advantage of the Learner extensions in train. For instance, instead of creating OneCycleScheduler manually as above, you can simply call Learner.fit_one_cycle:
learn.fit_one_cycle(1)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 0.040535 | 0.035062 | 0.986752 |
Note that if you're training a model for one of our supported applications, there's a lot of help available to you in the application modules.
For instance, let's use create_cnn (from vision) to quickly fine-tune a pre-trained ImageNet model for MNIST (not a very practical approach, of course, since MNIST is handwriting and our model is pre-trained on photos!).
learn = create_cnn(data, models.resnet18, metrics=accuracy)
learn.fit_one_cycle(1)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 0.163659 | 0.112767 | 0.958783 |