%matplotlib inline
from fastai.gen_doc.nbdoc import *
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
The learning rate finder plots the learning rate vs. loss relationship for a Learner. The idea is to reduce the amount of guesswork in picking a good starting learning rate.
Overview:
learn.lr_find()
learn.recorder.plot()
Technical Details: (first described by Leslie Smith)
Train the Learner over a few iterations. Start with a very low start_lr and change it at each mini-batch until it reaches a very high end_lr. Recorder will record the loss at each iteration. Plot those losses against the learning rate to find the optimal value before it diverges.
For a more intuitive explanation, please check out Sylvain Gugger's post
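Under the hood, the learning rate is increased (roughly exponentially) at every mini-batch during this mock training. Below is a minimal sketch of such a schedule, assuming an exponential interpolation between start_lr and end_lr; the exact scheme fastai uses may differ in detail:

import numpy as np

# Exponentially spaced learning rates from start_lr to end_lr over num_it steps
start_lr, end_lr, num_it = 1e-7, 10, 100
lrs = start_lr * (end_lr / start_lr) ** (np.arange(num_it) / (num_it - 1))
print(lrs[0], lrs[-1])  # 1e-07 ... 10.0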
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
def simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])
learn = simple_learner()
First we run this command to launch the search:
show_doc(Learner.lr_find)
learn.lr_find(stop_div=False, num_it=200)
LR Finder complete, type {learner_name}.recorder.plot() to see the graph.
Then we plot the loss versus the learning rates. We're interested in finding a good order of magnitude of learning rate, so we plot with a log scale.
learn.recorder.plot()
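If you want to look at the numbers behind the plot, the recorded values can be read back from the Recorder. A small sketch, assuming the usual lrs and losses attributes from fastai v1:

# Inspect every 50th recorded (learning rate, smoothed loss) pair
lrs, losses = learn.recorder.lrs, learn.recorder.losses
for lr, loss in list(zip(lrs, losses))[::50]:
    print(f"lr={lr:.2e}  loss={float(loss):.4f}")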
Then, we choose a value that is approximately in the middle of the sharpest downward slope. In this case, training with 3e-2 looks like it should work well:
simple_learner().fit(2, 3e-2)
Total time: 00:03
epoch   train_loss   valid_loss   accuracy
1       0.070224     0.039051     0.986752   (00:01)
2       0.038105     0.043696     0.985280   (00:01)
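If you prefer a programmatic starting point over eyeballing the curve, one common heuristic is to take the learning rate where the recorded loss drops fastest. This is only a sketch of that heuristic, not part of the API shown above:

import numpy as np

# Pick the lr at the steepest downward slope of the loss vs. log(lr) curve
lrs = np.array(learn.recorder.lrs)
losses = np.array([float(l) for l in learn.recorder.losses])
slopes = np.gradient(losses, np.log10(lrs))
suggested_lr = lrs[np.argmin(slopes)]
print(f"suggested lr: {suggested_lr:.2e}")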
Don't just pick the learning rate at the minimum of the loss curve:
learn = simple_learner()
learn.fit(2, 1e-0)
Total time: 00:03
epoch   train_loss   valid_loss   accuracy
1       0.724437     0.693147     0.495584   (00:01)
2       0.693758     0.693147     0.495584   (00:01)
Picking a value before the downward slope results in slow training:
learn = simple_learner()
learn.fit(2, 1e-3)
Total time: 00:03
epoch   train_loss   valid_loss   accuracy
1       0.184354     0.168152     0.940137   (00:01)
2       0.146272     0.143661     0.946516   (00:01)
show_doc(LRFinder)
class LRFinder [source]

LRFinder(learn:Learner, start_lr:float=1e-07, end_lr:float=10, num_it:int=100, stop_div:bool=True) :: LearnerCallback
Causes learn to go on a mock training from start_lr to end_lr for num_it iterations. Training is interrupted if the loss diverges, and the weight changes are reverted after the run completes.
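learn.lr_find is essentially a thin wrapper that creates this callback and runs a short fit with it. A rough sketch of driving it by hand (the exact bookkeeping inside lr_find is more involved, so treat this as illustrative):

import numpy as np

# Run the LR sweep manually by passing the callback to fit
num_it = 100
cb = LRFinder(learn, start_lr=1e-7, end_lr=10, num_it=num_it, stop_div=True)
epochs = int(np.ceil(num_it / len(learn.data.train_dl)))  # enough batches for num_it iterations
learn.fit(epochs, 1e-7, callbacks=[cb])
learn.recorder.plot()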
show_doc(LRFinder.on_train_end)
on_train_end [source]

on_train_end(kwargs:Any)
Clean up the learn model weights disturbed during the LRFinder exploration.
show_doc(LRFinder.on_batch_end)
on_batch_end [source]

on_batch_end(iteration:int, smooth_loss:TensorOrNumber, kwargs:Any)
Determine if the loss has run away and we should stop.
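In spirit, the check keeps track of the best smoothed loss seen so far and stops once the current smoothed loss grows well beyond it (or stops being a finite number). A sketch of that logic; the multiplier here is an assumption, not necessarily the exact value fastai uses:

import math

def loss_has_diverged(smooth_loss: float, best_loss: float, factor: float = 4.0) -> bool:
    # Stop when the smoothed loss is several times the best loss so far, or NaN
    return math.isnan(smooth_loss) or smooth_loss > factor * best_loss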
show_doc(LRFinder.on_train_begin)
on_train_begin [source]

on_train_begin(pbar, kwargs:Any)
Initialize optimizer and learner hyperparameters.
show_doc(LRFinder.on_epoch_end)