from fastai.gen_doc.nbdoc import *
from fastai.vision import *
from fastai.text import *
from fastai.callbacks import *
from fastai.basic_train import *
from fastai.train import *
from fastai import callbacks
fastai's training loop is highly extensible, with a rich callback system. See the callback docs if you're interested in writing your own callback. See below for a list of callbacks that are provided with fastai, grouped by the module they're defined in.
Every callback that is passed to Learner with the callback_fns parameter will be automatically stored as an attribute. The attribute name is snake-cased, so for instance ActivationStats will appear as learn.activation_stats (assuming your object is named learn).
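To make the naming rule concrete, here is a minimal sketch of a CamelCase-to-snake_case conversion in plain Python. It assumes a two-pass regular-expression approach; camel_to_snake is a hypothetical helper name used for illustration, not necessarily the function fastai calls internally.

```python
import re

# Hypothetical helper (assumed two-pass regex approach, for illustration only).
_FIRST_PASS = re.compile(r"(.)([A-Z][a-z]+)")    # insert "_" before a capitalised word
_SECOND_PASS = re.compile(r"([a-z0-9])([A-Z])")  # insert "_" between lowercase/digit and a capital

def camel_to_snake(name: str) -> str:
    """Convert a class name such as 'ActivationStats' to 'activation_stats'."""
    s = _FIRST_PASS.sub(r"\1_\2", name)
    return _SECOND_PASS.sub(r"\1_\2", s).lower()

print(camel_to_snake("ActivationStats"))  # activation_stats
print(camel_to_snake("CSVLogger"))        # csv_logger
```

The second pass catches boundaries the first one misses, such as the "lC" in "SaveModelCallback", while acronyms like "CSV" stay in one piece.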
Use Leslie Smith's learning rate finder to find a good learning rate for training your model. Let's see an example on the MNIST dataset with a simple CNN.
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
def simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])
learn = simple_learner()
learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot()
In this example, a learning rate around 2e-2 seems like the right fit.
lr = 2e-2
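The idea behind the learning rate finder can be sketched without any deep learning library: increase the learning rate exponentially at every iteration, record the loss, and stop once the loss clearly diverges. The toy below takes one SGD step per learning rate on f(w) = w², for which plain SGD diverges once lr > 1. The defaults (start_lr=1e-7, end_lr=10, num_it=100), the "4× the best loss" stopping rule and the helper names are assumptions for this sketch; fastai's own finder also works on a smoothed training loss.

```python
import math

def exp_schedule(start_lr, end_lr, num_it):
    """Yield learning rates growing exponentially from start_lr to end_lr."""
    for i in range(num_it):
        pct = i / (num_it - 1)
        yield start_lr * (end_lr / start_lr) ** pct

def lr_range_test(start_lr=1e-7, end_lr=10.0, num_it=100, w0=5.0):
    """Toy range test on f(w) = w**2: one SGD step per learning rate."""
    w, best, history = w0, math.inf, []
    for lr in exp_schedule(start_lr, end_lr, num_it):
        w = w - lr * 2 * w          # SGD step; the gradient of w**2 is 2w
        loss = w * w
        history.append((lr, loss))
        best = min(best, loss)
        if loss > 4 * best:         # assumed divergence criterion
            break
    return history

history = lr_range_test()
best_lr, best_loss = min(history, key=lambda p: p[1])
print(f"stopped after {len(history)} steps; lowest loss at lr={best_lr:.3g}")
```

On a real model, a common rule of thumb is to pick a value somewhat below the minimum, where the loss is still dropping steeply, rather than the minimum itself.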
Train with Leslie Smith's 1cycle annealing method. Let's train our simple learner using the one-cycle policy.
learn.fit_one_cycle(3, lr)
epoch | train_loss | valid_loss | accuracy |
1 | 0.111205 | 0.056460 | 0.979882 |
2 | 0.040632 | 0.023650 | 0.987733 |
3 | 0.021217 | 0.020044 | 0.991659 |
The learning rate and the momentum were changed during training; you can plot both with learn.recorder.plot_lr(show_moms=True) (more info on the dedicated documentation page).
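Here is a plain-Python sketch of the shape of such a schedule: during a warm-up the learning rate rises from lr_max/div_factor to lr_max with cosine annealing while the momentum falls, then the learning rate anneals down to a very small value while the momentum rises again. The parameter names and defaults (pct_start=0.3, div_factor=25, final_div=25e4, momentums 0.95/0.85) are assumptions for this illustration; check the one_cycle documentation for the values fastai actually uses.

```python
import math

def annealing_cos(start, end, pct):
    """Cosine interpolation: returns start at pct=0 and end at pct=1."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle(lr_max, total_iters, pct_start=0.3, div_factor=25.0,
              final_div=25e4, moms=(0.95, 0.85)):
    """Return a list of (learning_rate, momentum) pairs, one per iteration."""
    warm = int(total_iters * pct_start)
    sched = []
    for i in range(total_iters):
        if i < warm:   # phase 1: lr goes up, momentum goes down
            pct = i / warm
            lr = annealing_cos(lr_max / div_factor, lr_max, pct)
            mom = annealing_cos(moms[0], moms[1], pct)
        else:          # phase 2: lr goes down to a tiny value, momentum goes back up
            pct = (i - warm) / (total_iters - warm)
            lr = annealing_cos(lr_max, lr_max / final_div, pct)
            mom = annealing_cos(moms[1], moms[0], pct)
        sched.append((lr, mom))
    return sched

sched = one_cycle(lr_max=2e-2, total_iters=100)
peak = max(range(len(sched)), key=lambda i: sched[i][0])
print(peak, sched[peak])  # the peak learning rate is reached when the warm-up ends
```

Moving the momentum opposite to the learning rate is the usual design choice here: a high learning rate already adds plenty of noise to the updates, so momentum is lowered while it lasts.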
Data augmentation using the method from the paper "mixup: Beyond Empirical Risk Minimization". Adding mixup in fastai is very simple:
learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy]).mixup()
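In the mixup paper, each training example is replaced by a convex combination of two examples and of their targets, with a mixing coefficient λ drawn from a Beta(α, α) distribution. The sketch below shows that formulation on plain Python lists; alpha=0.4 and the mixup_batch name are assumptions for illustration, and fastai's callback may differ in details (for instance by mixing the losses rather than the one-hot targets).

```python
import random

def mixup_batch(xs, ys, alpha=0.4, rng=random):
    """Mix every example of a batch with a randomly chosen partner from the same batch.

    xs: list of feature vectors (lists of floats); ys: list of one-hot target vectors.
    Returns the mixed inputs, the mixed targets and the coefficient lambda.
    """
    lam = rng.betavariate(alpha, alpha)      # lambda ~ Beta(alpha, alpha)
    perm = list(range(len(xs)))
    rng.shuffle(perm)                        # partner j = perm[i] for example i
    mixed_x = [[lam * a + (1 - lam) * b for a, b in zip(xs[i], xs[j])]
               for i, j in enumerate(perm)]
    mixed_y = [[lam * a + (1 - lam) * b for a, b in zip(ys[i], ys[j])]
               for i, j in enumerate(perm)]
    return mixed_x, mixed_y, lam

xs = [[0.0, 0.0], [1.0, 1.0]]
ys = [[1.0, 0.0], [0.0, 1.0]]
mixed_x, mixed_y, lam = mixup_batch(xs, ys, rng=random.Random(0))
print(lam, mixed_x, mixed_y)
```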
Log the results of training to a CSV file with CSVLogger:
learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy, error_rate], callback_fns=[CSVLogger])
learn.fit(3)
epoch | train_loss | valid_loss | accuracy | error_rate |
1 | 0.125326 | 0.103473 | 0.963690 | 0.036310 |
2 | 0.077392 | 0.059223 | 0.977920 | 0.022080 |
3 | 0.065756 | 0.081031 | 0.969578 | 0.030422 |
You can then read the CSV file back as a pandas DataFrame:
learn.csv_logger.read_logged_file()
 | epoch | train_loss | valid_loss | accuracy | error_rate |
0 | 1 | 0.125326 | 0.103473 | 0.963690 | 0.036310 |
1 | 2 | 0.077392 | 0.059223 | 0.977920 | 0.022080 |
2 | 3 | 0.065756 | 0.081031 | 0.969578 | 0.030422 |
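Under the hood, a logger like this only needs to write a header row and then one row of metrics per epoch. Here is a minimal standard-library sketch of that idea, using the first two epochs from the table above; log_epochs and read_log are hypothetical helpers, not part of fastai, and an in-memory buffer stands in for the file on disk.

```python
import csv
import io

def log_epochs(stream, names, rows):
    """Write a header row, then one CSV row per epoch."""
    writer = csv.writer(stream)
    writer.writerow(names)
    for row in rows:
        writer.writerow(row)

def read_log(stream):
    """Read the log back as a list of dicts keyed by column name (values are strings)."""
    stream.seek(0)
    return list(csv.DictReader(stream))

buf = io.StringIO()  # stands in for a file on disk
log_epochs(buf, ["epoch", "train_loss", "valid_loss", "accuracy", "error_rate"],
           [[1, 0.125326, 0.103473, 0.963690, 0.036310],
            [2, 0.077392, 0.059223, 0.977920, 0.022080]])
records = read_log(buf)
print(records[1]["accuracy"])
```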
Create your own multi-stage annealing schemes with a convenient API. To illustrate, let's implement a two-phase schedule.
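def fit_odd_schedule(learn, lr, mom):
    n = len(learn.data.train_dl)  # number of batches (iterations) in one epoch
    # phase 1: one epoch with cosine annealing; phase 2: two epochs with degree-2 polynomial annealing
    phases = [TrainingPhase(n, lr, mom, lr_anneal=annealing_cos), TrainingPhase(n*2, lr, mom, lr_anneal=annealing_poly(2))]
    sched = GeneralScheduler(learn, phases)
    learn.callbacks.append(sched)  # register the scheduler so fit() uses it
    total_epochs = 3  # 1 + 2 epochs, matching the phase lengths above
    learn.fit(total_epochs)
learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)
fit_odd_schedule(learn, 1e-3, 0.9)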
epoch | train_loss | valid_loss | accuracy |
1 | 0.178648 | 0.161728 | 0.944553 |
2 | 0.142739 | 0.132620 | 0.957802 |
3 | 0.135239 | 0.129183 | 0.960255 |
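Conceptually, a multi-phase scheduler walks through a list of phases, each lasting a number of iterations and interpolating a hyper-parameter from a start value to an end value with its own annealing function. Here is a plain-Python sketch of that idea; the (n_iters, start, end, anneal) tuple format, the run_phases name and the start/end values are hypothetical, and the two annealing functions follow the usual cosine and polynomial formulas rather than being copied from fastai.

```python
import math

def annealing_cos(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def annealing_poly(degree):
    """Return a polynomial annealing function of the given degree."""
    def anneal(start, end, pct):
        return end + (start - end) * (1 - pct) ** degree
    return anneal

def run_phases(phases):
    """phases: list of (n_iters, start, end, anneal_fn); returns one value per iteration."""
    values = []
    for n_iters, start, end, anneal in phases:
        for i in range(n_iters):
            values.append(anneal(start, end, i / n_iters))
    return values

n = 10  # hypothetical number of batches per epoch
lrs = run_phases([(n, 1e-3, 1e-5, annealing_cos),
                  (2 * n, 1e-3, 1e-5, annealing_poly(2))])
print(len(lrs))  # 30 iterations, i.e. 3 epochs of n batches
```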
Use mixed-precision (fp16) training to take advantage of the tensor cores on recent NVIDIA GPUs, for a speedup of 200% or more.
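Training in fp16 usually relies on loss scaling: the loss is multiplied by a large factor before backpropagation so that small gradients do not underflow in half precision, and the gradients are divided by the same factor before the optimizer step. With dynamic loss scaling, the factor is halved (and the step skipped) whenever an overflow produces inf or NaN gradients, and raised again after a run of clean steps. The toy class below sketches only that bookkeeping; the DynamicLossScaler name, the initial scale of 2**16 and the growth interval of 1000 steps are assumptions, not fastai's settings.

```python
import math

class DynamicLossScaler:
    """Toy dynamic loss scaler (hypothetical; illustrates the bookkeeping only)."""

    def __init__(self, init_scale=2.0 ** 16, growth_interval=1000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.good_steps = 0

    def update(self, scaled_grads):
        """Return unscaled gradients, or None if this step must be skipped."""
        if any(math.isinf(g) or math.isnan(g) for g in scaled_grads):
            self.scale /= 2          # overflow: shrink the scale and skip this step
            self.good_steps = 0
            return None
        self.good_steps += 1
        grads = [g / self.scale for g in scaled_grads]
        if self.good_steps % self.growth_interval == 0:
            self.scale *= 2          # long run without overflow: try a larger scale
        return grads

scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
print(scaler.update([float("inf")]), scaler.scale)
print(scaler.update([8.0]), scaler.scale)
```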
Convenient wrapper for registering and automatically deregistering PyTorch hooks. Also contains a predefined hook callback, ActivationStats.
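The register/deregister pattern can be sketched without PyTorch. Below, ToyModule is a hypothetical stand-in for a layer; its register_forward_hook returns a handle with a remove() method, mirroring the shape of PyTorch's nn.Module.register_forward_hook. The hypothetical Hook wrapper registers on creation and removes itself when used as a context manager, and the hook function records the mean and standard deviation of the layer's output, the kind of statistic ActivationStats collects.

```python
import statistics

class _Handle:
    """Removes one hook from a module's hook table when remove() is called."""
    def __init__(self, hooks, hook_id):
        self.hooks, self.hook_id = hooks, hook_id

    def remove(self):
        self.hooks.pop(self.hook_id, None)

class ToyModule:
    """Hypothetical layer that calls its forward hooks as hook(module, input, output)."""
    def __init__(self, fn):
        self.fn = fn
        self._hooks = {}
        self._next_id = 0

    def register_forward_hook(self, hook):
        hook_id = self._next_id
        self._next_id += 1
        self._hooks[hook_id] = hook
        return _Handle(self._hooks, hook_id)

    def __call__(self, x):
        out = self.fn(x)
        for hook in list(self._hooks.values()):
            hook(self, x, out)
        return out

class Hook:
    """Register a hook on creation; deregister it automatically when leaving a `with` block."""
    def __init__(self, module, hook_fn):
        self.stored = []
        self.handle = module.register_forward_hook(
            lambda m, inp, out: self.stored.append(hook_fn(m, inp, out)))

    def remove(self):
        self.handle.remove()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.remove()

# Record the mean/std of the outputs, like an activation-statistics hook would.
layer = ToyModule(lambda xs: [2 * x for x in xs])
with Hook(layer, lambda m, inp, out: (statistics.mean(out), statistics.pstdev(out))) as h:
    layer([1.0, 2.0, 3.0])
layer([5.0])  # the hook has been removed, so nothing is recorded here
print(h.stored)
```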
Callback taking care of all the tweaks to train an RNN.
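Two tweaks commonly applied when training AWD-LSTM-style language models are activation regularization (AR), a penalty on the squared magnitude of the dropped-out outputs, and temporal activation regularization (TAR), a penalty on how much the raw outputs change between consecutive time steps. A plain-Python sketch of both terms follows; the list-of-vectors layout, the function names and the alpha=2, beta=1 weights are assumptions for illustration, and in practice the terms are computed on tensors and added to the loss.

```python
def mean_square(vectors):
    """Mean of the squared entries over a list of equal-length vectors."""
    values = [v * v for vec in vectors for v in vec]
    return sum(values) / len(values)

def ar_tar_penalty(raw_out, dropped_out, alpha=2.0, beta=1.0):
    """AR on the dropped-out outputs, TAR on step-to-step changes of the raw outputs.

    raw_out and dropped_out are lists over time steps of activation vectors.
    """
    ar = alpha * mean_square(dropped_out)
    steps = [[cur - prev for prev, cur in zip(h_prev, h_cur)]
             for h_prev, h_cur in zip(raw_out, raw_out[1:])]
    tar = beta * mean_square(steps) if steps else 0.0
    return ar, tar

raw = [[1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]
dropped = [[1.0, 0.0], [2.0, 0.0], [0.0, 0.0]]
print(ar_tar_penalty(raw, dropped))
```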
Stop training if the loss reaches NaN.
Stop training early if a monitored metric (or the validation loss) stops improving.
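The decision logic behind these two trackers fits in a few lines. The sketch below stops as soon as the latest value is NaN, or once patience epochs have passed without the monitored value improving on the best one by more than min_delta. should_stop is a hypothetical helper, and the patience, min_delta and mode parameter names are assumptions for illustration.

```python
import math

def should_stop(history, patience=3, min_delta=0.0, mode="min"):
    """Decide whether to stop, given the monitored values recorded so far.

    mode="min" for losses (lower is better), mode="max" for metrics like accuracy.
    """
    if history and math.isnan(history[-1]):
        return True                              # terminate on NaN
    best, wait = None, 0
    for value in history:
        if mode == "min":
            improved = best is None or value < best - min_delta
        else:
            improved = best is None or value > best + min_delta
        if improved:
            best, wait = value, 0                # new best: reset the counter
        else:
            wait += 1                            # one more epoch without improvement
    return wait >= patience

print(should_stop([0.5, 0.4, 0.41, 0.42, 0.43], patience=3))
```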
Save the model at every epoch, or the best model for a given metric/validation loss.
learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)
learn.fit_one_cycle(3, 1e-4, callbacks=[SaveModelCallback(learn, every='epoch', monitor='accuracy')])
epoch | train_loss | valid_loss | accuracy |
1 | 0.652836 | 0.629737 | 0.613346 |
2 | 0.546113 | 0.517567 | 0.902355 |
3 | 0.495621 | 0.489153 | 0.916585 |
!ls ~/.fastai/data/mnist_sample/models
bestmodel_1.pth bestmodel_2.pth bestmodel_3.pth
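The file names above follow a simple rule: with every='epoch', one checkpoint per epoch is written with the epoch number appended, while in the "save on improvement" mode assumed here a single file is overwritten each time the monitored value improves. A minimal sketch of that decision, replaying the accuracies from the table above; checkpoint_names is a hypothetical helper, and the bestmodel name and 1-based numbering are taken from the listing rather than from fastai's source.

```python
def checkpoint_names(metric_values, every="improvement", name="bestmodel", mode="max"):
    """Return the checkpoint file written at each epoch (None when nothing is saved)."""
    written, best = [], None
    for epoch, value in enumerate(metric_values, start=1):
        if every == "epoch":
            written.append(f"{name}_{epoch}.pth")      # one file per epoch
        elif best is None or (value > best if mode == "max" else value < best):
            best = value
            written.append(f"{name}.pth")               # overwrite the best model
        else:
            written.append(None)                        # no improvement: keep the old file
    return written

accuracies = [0.613346, 0.902355, 0.916585]
print(checkpoint_names(accuracies, every="epoch"))
print(checkpoint_names([0.61, 0.60, 0.70]))
```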
Reduce the learning rate by a given factor each time a monitored metric (or the validation loss) stops improving.
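The schedule this produces is easy to simulate. In the sketch below, a lower-is-better metric is tracked, and each time it fails to improve for patience consecutive epochs, the learning rate is multiplied by factor for the following epochs. reduce_on_plateau is a hypothetical helper, and its factor and patience defaults are assumptions.

```python
def reduce_on_plateau(lr, metric_values, factor=0.2, patience=1, min_delta=0.0):
    """Return the learning rate in effect at each epoch for a lower-is-better metric."""
    lrs, best, wait = [], float("inf"), 0
    for value in metric_values:
        lrs.append(lr)               # lr used during this epoch
        if value < best - min_delta:
            best, wait = value, 0    # improvement: reset the counter
        else:
            wait += 1
            if wait >= patience:     # plateau: shrink lr for the next epochs
                lr *= factor
                wait = 0
    return lrs

print(reduce_on_plateau(0.1, [1.0, 0.9, 0.95, 0.96, 0.8], factor=0.5))
```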
GPU and general RAM profiling callback.
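The general idea of measuring peak memory around a piece of work can be sketched with Python's standard tracemalloc module. Note that tracemalloc only sees allocations made through Python's allocator, not GPU memory or memory held by native libraries, so this is an illustration of the principle rather than a replacement for the callback; peak_mem_mb is a hypothetical helper.

```python
import tracemalloc

def peak_mem_mb(fn, *args, **kwargs):
    """Run fn and return (result, peak traced Python allocation in MB)."""
    tracemalloc.start()
    try:
        result = fn(*args, **kwargs)
        _, peak = tracemalloc.get_traced_memory()   # (current, peak) in bytes
    finally:
        tracemalloc.stop()
    return result, peak / 2 ** 20

result, peak = peak_mem_mb(lambda: len([0] * 1_000_000))
print(result, f"{peak:.1f} MB")
```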
Callbacks defined in train and basic_train:
Clip the gradients during training.
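Gradient clipping by norm rescales all gradients together whenever their global L2 norm exceeds a threshold, which keeps the update direction but caps its size. Here is a plain-Python sketch, assuming norm-based clipping; clip_grad_norm is a hypothetical helper, while PyTorch's torch.nn.utils.clip_grad_norm_ works on the same principle, with a small epsilon added to the denominator.

```python
import math

def clip_grad_norm(grads, max_norm):
    """Scale grads so their global L2 norm is at most max_norm; also return the original norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

clipped, norm = clip_grad_norm([3.0, 4.0], max_norm=1.0)
print(norm, clipped)
```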