from fastai.gen_doc.nbdoc import *
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
This module gathers the callbacks that track one of the metrics computed at the end of each epoch and use it to make decisions about training. To show examples of use, we'll use our sample of MNIST and a simple CNN model.
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
show_doc(callbacks.TerminateOnNaNCallback)
Sometimes, training diverges and the loss goes to nan. In that case, there's no point continuing, so this callback stops the training.
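To see why a huge learning rate such as the 1e4 used below ends in NaN, here is a framework-free sketch, not fastai code. gd_until_nan is a hypothetical helper that runs plain gradient descent on the toy loss f(x) = x**2. Each step multiplies x by roughly -2*lr, so the loss first overflows to inf, and an update of the form inf - inf then produces NaN. At that point the loop stops, which is what the callback does for a real model.

```python
import math

def gd_until_nan(lr, x=1.0, max_steps=1000):
    """Gradient descent on the toy loss f(x) = x**2, stopped at the first NaN loss.

    Hypothetical helper for illustration. Returns the list of losses seen;
    the last one is NaN if training diverged.
    """
    losses = []
    for step in range(max_steps):
        loss = x * x
        losses.append(loss)
        if math.isnan(loss):
            # No point continuing: every later update would also be NaN.
            print(f"Step {step}: Invalid loss, terminating training.")
            break
        x = x - lr * 2 * x  # the gradient of x**2 is 2x
    return losses

diverged = gd_until_nan(1e4)
print(diverged[-1])  # nan
```

With lr=1e4 the returned list ends in nan well before max_steps; with a small rate such as 0.1 the loss shrinks steadily and all max_steps iterations run.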
model = simple_cnn((3,16,16,2))
learn = Learner(data, model, metrics=[accuracy])
learn.fit_one_cycle(2,1e4)
Total time: 00:04
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
1 | nan | nan | 0.504416 | 00:02 |
2 | nan | nan | 0.504416 | 00:02 |
Using this callback prevents that situation: training stops as soon as the loss becomes NaN.
model = simple_cnn((3,16,16,2))
learn = Learner(data, model, metrics=[accuracy], callbacks=[TerminateOnNaNCallback()])
learn.fit(2,1e4)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
Epoch/Batch (0/5): Invalid loss, terminating training.
show_doc(EarlyStoppingCallback)
class EarlyStoppingCallback(learn:Learner, monitor:str='val_loss', mode:str='auto', min_delta:int=0, patience:int=0) :: TrackerCallback
A TrackerCallback that terminates training when the monitored quantity stops improving.
This callback tracks the quantity in monitor during the training of learn. mode can be forced to 'min' or 'max'; with the default 'auto', it tries to determine automatically whether the quantity should be as low as possible (like the validation loss) or as high as possible (like accuracy). It stops training after patience epochs in which the quantity hasn't improved by min_delta.
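The bookkeeping behind patience and min_delta can be sketched in a few lines of plain Python. EarlyStopper and first_stop are hypothetical names, not part of fastai, and mode is taken as already resolved to 'min' or 'max'. One convention is an assumption: here training stops once the count of non-improving epochs exceeds patience. Libraries differ on whether the count must reach or exceed it, so the exact stopping epoch may not match the output shown in the example that follows.

```python
import math

class EarlyStopper:
    """Hypothetical stand-in for the early-stopping logic (not fastai's class).

    mode='min': lower is better (e.g. a loss); mode='max': higher is better.
    An epoch counts as an improvement only if it beats the best value by
    more than min_delta.
    """
    def __init__(self, mode="min", min_delta=0.0, patience=0):
        self.mode, self.min_delta, self.patience = mode, min_delta, patience
        self.best = math.inf if mode == "min" else -math.inf
        self.wait = 0  # consecutive epochs without improvement

    def improved(self, value):
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def on_epoch_end(self, value):
        """Record one epoch's value; return True when training should stop."""
        if self.improved(value):
            self.best, self.wait = value, 0
            return False
        self.wait += 1
        return self.wait > self.patience  # assumption: stop once patience is exceeded


def first_stop(stopper, values):
    """Return the 0-based epoch at which training would stop, or None."""
    for epoch, value in enumerate(values):
        if stopper.on_epoch_end(value):
            return epoch
    return None


# A flat accuracy: the first epoch sets the best value, and later epochs
# never beat it by more than 0.01.
print(first_stop(EarlyStopper(mode="max", min_delta=0.01, patience=3), [0.4966] * 10))  # 4
```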
model = simple_cnn((3,16,16,2))
learn = Learner(data, model, metrics=[accuracy],
callback_fns=[partial(EarlyStoppingCallback, monitor='accuracy', min_delta=0.01, patience=3)])
learn.fit(50,1e-42)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 0.692837 | 0.692778 | 0.496565 |
2 | 0.692831 | 0.692778 | 0.496565 |
3 | 0.692877 | 0.692778 | 0.496565 |
Epoch 4: early stopping
show_doc(SaveModelCallback)
class SaveModelCallback(learn:Learner, monitor:str='val_loss', mode:str='auto', every:str='improvement', name:str='bestmodel') :: TrackerCallback
A TrackerCallback that saves the model when the monitored quantity is best.
This callback tracks the quantity in monitor during the training of learn. mode can be forced to 'min' or 'max'; with the default 'auto', it tries to determine automatically whether the quantity should be as low as possible (like the validation loss) or as high as possible (like accuracy). It saves the model under name at the moments given by every: after each epoch ('epoch') or each time the quantity improves ('improvement'). If every='improvement', the best model is loaded back at the end of training.
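The difference between the two values of every can be illustrated without touching the disk. plan_saves is a hypothetical helper, not fastai code: it only returns the (epoch, filename) pairs that would be written, with mode already resolved to 'min' or 'max'. The '{name}_{epoch}' naming used for every='epoch' is an assumption of this sketch.

```python
import math

def plan_saves(values, mode="min", every="improvement", name="bestmodel"):
    """Return the (epoch, filename) checkpoints a save-on-best tracker would write.

    Hypothetical helper: nothing is saved, we only record what would be.
    """
    if every not in ("improvement", "epoch"):
        raise ValueError("every must be 'improvement' or 'epoch'")
    better = (lambda a, b: a < b) if mode == "min" else (lambda a, b: a > b)
    best = math.inf if mode == "min" else -math.inf
    saves = []
    for epoch, value in enumerate(values):
        if every == "epoch":
            # One file per epoch (assumed naming scheme).
            saves.append((epoch, f"{name}_{epoch}"))
        elif better(value, best):
            best = value
            # Same name each time: the previous best checkpoint is overwritten.
            saves.append((epoch, name))
    return saves

print(plan_saves([0.5, 0.4, 0.45, 0.3]))
# [(0, 'bestmodel'), (1, 'bestmodel'), (3, 'bestmodel')]
```

With every='improvement' the same file is overwritten each time, so the last entry is the checkpoint that would be loaded back at the end of training.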
show_doc(ReduceLROnPlateauCallback)
class ReduceLROnPlateauCallback(learn:Learner, monitor:str='val_loss', mode:str='auto', patience:int=0, factor:float=0.2, min_delta:int=0) :: TrackerCallback
A TrackerCallback that reduces the learning rate when a metric has stopped improving.
This callback tracks the quantity in monitor during the training of learn. mode can be forced to 'min' or 'max'; with the default 'auto', it tries to determine automatically whether the quantity should be as low as possible (like the validation loss) or as high as possible (like accuracy). It multiplies the learning rate by factor after patience epochs in which the quantity hasn't improved by min_delta.
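A framework-free sketch of the plateau rule follows. lr_schedule is a hypothetical helper, not fastai's implementation; it returns the learning rate in effect after each epoch. Two choices are assumptions of the sketch: the reduction happens once the number of non-improving epochs exceeds patience, and the counter restarts after every reduction.

```python
import math

def lr_schedule(values, lr, mode="min", patience=0, factor=0.2, min_delta=0.0):
    """Return the learning rate in effect after each epoch's monitored value."""
    best = math.inf if mode == "min" else -math.inf
    wait = 0  # consecutive epochs without improvement
    lrs = []
    for value in values:
        if mode == "min":
            improved = value < best - min_delta
        else:
            improved = value > best + min_delta
        if improved:
            best, wait = value, 0
        else:
            wait += 1
            if wait > patience:  # assumed convention: patience must be exceeded
                lr *= factor     # e.g. factor=0.2 divides the rate by 5
                wait = 0         # count the next plateau from scratch
        lrs.append(lr)
    return lrs

print(lr_schedule([1.0, 1.0, 1.0, 1.0], lr=1.0, patience=1, factor=0.5))
# [1.0, 1.0, 0.5, 0.5]
```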
show_doc(TrackerCallback)
class TrackerCallback(learn:Learner, monitor:str='val_loss', mode:str='auto') :: LearnerCallback
A LearnerCallback that keeps track of the best value in monitor.
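The tracking shared by these callbacks can be sketched as follows. BestTracker is a hypothetical, framework-free class, not fastai's TrackerCallback: it reads the monitored value from a plain dict of per-epoch results (a stand-in for what get_monitor_value retrieves) and keeps the best one. The 'auto' rule used here, minimizing any name that contains 'loss' and maximizing everything else, is an assumed heuristic, and raising on an unknown mode is a choice of this sketch.

```python
import math

class BestTracker:
    """Hypothetical, framework-free stand-in for best-value tracking."""
    def __init__(self, monitor="val_loss", mode="auto"):
        if mode not in ("auto", "min", "max"):
            raise ValueError("mode must be 'auto', 'min' or 'max'")
        if mode == "auto":
            # Assumed heuristic: names containing 'loss' are minimized,
            # everything else (e.g. 'accuracy') is maximized.
            mode = "min" if "loss" in monitor else "max"
        self.monitor, self.mode = monitor, mode
        self.best = math.inf if mode == "min" else -math.inf

    def update(self, metrics):
        """Read the monitored value from a dict of epoch results.

        Returns True if it is a new best, False otherwise (including when
        the monitored name is missing from this epoch's results).
        """
        value = metrics.get(self.monitor)
        if value is None:
            return False
        is_better = value < self.best if self.mode == "min" else value > self.best
        if is_better:
            self.best = value
        return is_better

tracker = BestTracker()                 # monitors 'val_loss', resolved to 'min'
tracker.update({"val_loss": 0.5})       # new best
tracker.update({"val_loss": 0.6})       # worse, best stays 0.5
print(tracker.mode, tracker.best)       # min 0.5
```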
show_doc(SaveModelCallback.on_epoch_end)
show_doc(TerminateOnNaNCallback.on_batch_end)
on_batch_end(last_loss, epoch, num_batch, **kwargs:Any)
Called at the end of the batch.
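The signature above reflects how callbacks are invoked: the training loop passes its state as keyword arguments, and **kwargs lets a callback accept the ones it does not use. The sketch below is a hypothetical toy loop, not fastai's Learner. NaNStopper and toy_fit are illustrative names, and the extra train flag stands in for any state a callback might ignore.

```python
import math
from typing import Any, List, Optional, Tuple

class NaNStopper:
    """Hypothetical callback using the on_batch_end signature shown above."""
    def on_batch_end(self, last_loss, epoch, num_batch, **kwargs: Any) -> bool:
        # Keyword arguments this callback does not need end up in kwargs.
        if math.isnan(last_loss):
            print(f"Epoch/Batch ({epoch}/{num_batch}): Invalid loss, terminating training.")
            return True
        return False

def toy_fit(epochs: int, batch_losses: List[float], callbacks) -> Optional[Tuple[int, int]]:
    """Replay the same batch losses each epoch; return (epoch, batch) where training stopped."""
    for epoch in range(epochs):
        for num_batch, loss in enumerate(batch_losses):
            state = dict(last_loss=loss, epoch=epoch, num_batch=num_batch, train=True)
            # Build the full list first so every callback runs, then check for a stop request.
            if any([cb.on_batch_end(**state) for cb in callbacks]):
                return epoch, num_batch
    return None

print(toy_fit(2, [0.7, float("nan")], [NaNStopper()]))
```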
show_doc(EarlyStoppingCallback.on_train_begin)
show_doc(SaveModelCallback.on_train_end)
show_doc(ReduceLROnPlateauCallback.on_epoch_end)
show_doc(EarlyStoppingCallback.on_epoch_end)
show_doc(TerminateOnNaNCallback.on_epoch_end)
show_doc(TrackerCallback.on_train_begin)
show_doc(ReduceLROnPlateauCallback.on_train_begin)
show_doc(TrackerCallback.get_monitor_value)
get_monitor_value()