from fastai.gen_doc.nbdoc import *
from fastai.callback import *
from fastai.basics import *
fastai provides a powerful callback system, which is documented on the callbacks page; look on that page if you're just looking for how to use existing callbacks. If you want to create your own, you'll need to use the classes discussed below.
A key motivation for the callback system is that additional functionality can be implemented entirely in a single callback, so it's easy to read. Each technique thus lives in its own callback, where all the interventions it makes in training are stated clearly in one place. For instance, the LRFinder callback, on top of running the fit function with exponentially growing LRs, needs to handle some preparation and clean-up, and all this code can live in the same callback, so we know exactly what it is doing and where to look if we need to change something.
In addition, it allows our fit function to be very clean and simple, yet still easily extended. So far, in implementing a number of recent papers, we haven't yet come across any situation where we had to modify our training loop source code - we've been able to use callbacks every time.
show_doc(Callback)
To create a new type of callback, you'll need to inherit from this class, and implement one or more methods as required for your purposes. Perhaps the easiest way to get started is to look at the source code for some of the pre-defined fastai callbacks. You might be surprised at how simple they are! For instance, here is the entire source code for GradientClipping:
@dataclass
class GradientClipping(LearnerCallback):
    clip:float
    def on_backward_end(self, **kwargs):
        if self.clip:
            nn.utils.clip_grad_norm_(self.learn.model.parameters(), self.clip)
You generally want your custom callback constructor to take a Learner parameter, e.g.:
@dataclass
class MyCallback(Callback):
    learn:Learner
Note that this allows the callback user to just pass your callback name to callback_fns when constructing their Learner, since that always passes self when constructing callbacks from callback_fns. In addition, by passing the learner, this callback will have access to everything: e.g. all the inputs/outputs as they are calculated, the losses, and also the data loaders, the optimizer, etc. At any time:

- Changing self.learn.data.train_dl or self.learn.data.valid_dl will change them inside the fit function (we just need to pass the DataBunch object to the fit function and not data.train_dl/data.valid_dl)
- Changing self.learn.opt.opt (we have an OptimWrapper on top of the actual optimizer) will change it inside the fit function.

In any of the callbacks you can unpack in the kwargs:

- n_epochs, contains the number of epochs the training will take in total
- epoch, contains the number of the current epoch
- iteration, contains the number of iterations done since the beginning of training
- num_batch, contains the number of the batch we're at in the dataloader
- last_input, contains the last input that went through the model (possibly updated by a callback)
- last_target, contains the last target that went through the model (possibly updated by a callback)
- last_output, contains the last output produced by the model (possibly updated by a callback)
- last_loss, contains the last loss computed (possibly updated by a callback)
- smooth_loss, contains the smoothed version of the loss
- last_metrics, contains the last validation loss and metrics computed
- pbar, the progress bar
- train, flag to know if we're in training mode or not
- stop_training, that will stop the training at the end of the current epoch if True
- stop_epoch, that will break the current epoch loop
- skip_step, that will skip the next optimizer step
- skip_zero, that will skip the next zeroing of the gradients

When returning a dictionary with those key names, the state of the CallbackHandler will be updated with any of those changes, so in any Callback, you can change those values.
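For instance, here is a minimal sketch of a callback that reads one of these kwargs and returns a dictionary to update the CallbackHandler state (the class name and threshold are made up for illustration, and the imports at the top of this page are assumed):
class StopAfterNIterations(Callback):
    "Sketch: ask the training loop to stop once `iteration` reaches a given count."
    def __init__(self, n_iters:int=1000): self.n_iters = n_iters
    def on_batch_end(self, iteration, **kwargs):
        # Returning a dict with a recognized key updates the CallbackHandler state.
        if iteration >= self.n_iters: return {'stop_training': True}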
All of these methods are optional; your subclass can handle as many or as few as you require.
show_doc(Callback.on_train_begin)
on_train_begin(**kwargs:Any)
To initialize constants in the callback.
Here we can initialize anything we need. The optimizer has now been initialized. We can change any hyper-parameters by typing, for instance:
self.opt.lr = new_lr
self.opt.mom = new_mom
self.opt.wd = new_wd
self.opt.beta = new_beta
show_doc(Callback.on_epoch_begin)
on_epoch_begin(**kwargs:Any)
At the beginning of each epoch.
This is not technically required since we have on_train_begin for epoch 0 and on_epoch_end for all the other epochs, yet it makes writing code that needs to run at the beginning of every epoch easier and more readable.
show_doc(Callback.on_batch_begin)
on_batch_begin(**kwargs:Any)
Set HP before the output and loss are computed.
Here is the perfect place to prepare everything before the model is called. Example: change the values of the hyper-parameters (if we don't do it in on_batch_end instead).
At the end of this event, xb and yb will be set to last_input and last_target in the state of the CallbackHandler.
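As an illustration (the class name and noise level are hypothetical, and a single-tensor input is assumed), a callback could replace what the model sees by returning an updated last_input:
class AddInputNoise(Callback):
    "Sketch: perturb the inputs before each forward pass, only in training mode."
    def __init__(self, std:float=0.01): self.std = std
    def on_batch_begin(self, last_input, train, **kwargs):
        if not train: return
        # The returned dict updates the CallbackHandler state, so the model receives
        # the noisy tensor instead of the original input.
        return {'last_input': last_input + self.std * torch.randn_like(last_input)}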
show_doc(Callback.on_loss_begin)
on_loss_begin(**kwargs:Any)
Called after forward pass but before loss has been computed.
Here is the place to run some code that needs to be executed after the output has been computed but before the loss computation. Example: putting the output back in FP32 when training in mixed precision.
At the end of this event, the output will be set to last_output in the state of the CallbackHandler.
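A minimal sketch of that mixed-precision example (the class name is hypothetical; fastai's own mixed-precision callback does more than this):
class OutputToFP32(Callback):
    "Sketch: cast a half-precision model output back to full precision before the loss."
    def on_loss_begin(self, last_output, **kwargs):
        # Returning the dict updates `last_output` in the CallbackHandler state.
        return {'last_output': last_output.float()}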
show_doc(Callback.on_backward_begin)
on_backward_begin(**kwargs:Any)
Called after the forward pass and the loss has been computed, but before backprop.
Here is the place to run some code that needs to be executed after the loss has been computed but before the gradient computation. Example: reg_fn in RNNs.
At the end of this event, the loss will be set to last_loss in the state of the CallbackHandler.
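As a sketch in the same spirit (the class name and coefficient are hypothetical, and a single-tensor output is assumed), an extra penalty could be folded into the loss before backprop:
class ActivationPenalty(Callback):
    "Sketch: add an activation penalty to the loss before the backward pass."
    def __init__(self, alpha:float=1e-4): self.alpha = alpha
    def on_backward_begin(self, last_loss, last_output, **kwargs):
        # The returned dict updates `last_loss`, so the penalized loss is backpropagated.
        return {'last_loss': last_loss + self.alpha * last_output.float().pow(2).mean()}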
show_doc(Callback.on_backward_end)
on_backward_end(**kwargs:Any)
Called after backprop but before optimizer step. Useful for true weight decay in AdamW.
Here is the place to run some code that needs to be executed after the gradients have been computed but before the optimizer is called.
If skip_step is True at the end of this event, the optimizer step is skipped.
show_doc(Callback.on_step_end)
on_step_end(**kwargs:Any)
Called after the step of the optimizer but before the gradients are zeroed.
Here is the place to run some code that needs to be executed after the optimizer step but before the gradients are zeroed.
If skip_zero is True at the end of this event, the gradients are not zeroed.
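Together, skip_step and skip_zero make it possible to sketch gradient accumulation. This is a rough illustration only: the class name is made up and the loss is not rescaled, as a complete implementation would do.
class GradientAccumulator(Callback):
    "Sketch: only step the optimizer every `n_step` batches, letting gradients accumulate."
    def __init__(self, n_step:int=4): self.n_step = n_step
    def _accumulating(self, num_batch): return (num_batch + 1) % self.n_step != 0
    def on_backward_end(self, num_batch, **kwargs):
        # Skip the optimizer step on intermediate batches.
        if self._accumulating(num_batch): return {'skip_step': True}
    def on_step_end(self, num_batch, **kwargs):
        # Also keep the gradients (skip the zeroing) so they keep accumulating.
        if self._accumulating(num_batch): return {'skip_zero': True}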
show_doc(Callback.on_batch_end)
on_batch_end(**kwargs:Any)
Called at the end of the batch.
Here is the place to run some code that needs to be executed after a batch is fully done. Example: change the values of the hyper-parameters (if we don't do it in on_batch_begin instead).
If stop_epoch is True at the end of this event, the current epoch is interrupted (example: the LR finder stops the training when the loss explodes).
show_doc(Callback.on_epoch_end)
on_epoch_end(**kwargs:Any)
Called at the end of an epoch.
Here is the place to run some code that needs to be executed at the end of an epoch. Example: Save the model if we have a new best validation loss/metric.
If stop_training is True at the end of this event, the training stops (example: early stopping).
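For example, a bare-bones early-stopping sketch (the class name is hypothetical; fastai ships a more complete EarlyStoppingCallback) could look like:
class SimpleEarlyStopping(Callback):
    "Sketch: stop training when the validation loss has not improved for `patience` epochs."
    def __init__(self, patience:int=3):
        self.patience, self.best, self.wait = patience, float('inf'), 0
    def on_epoch_end(self, last_metrics, **kwargs):
        val_loss = last_metrics[0]   # first element of `last_metrics` is the validation loss
        if val_loss is None: return  # no validation set
        if val_loss < self.best: self.best, self.wait = val_loss, 0
        else:
            self.wait += 1
            if self.wait >= self.patience: return {'stop_training': True}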
show_doc(Callback.on_train_end)
on_train_end(**kwargs:Any)
Useful for cleaning up things and saving files/models.
Here is the place to tidy everything up. It's always executed, even if there was an error during the training loop, and has an extra kwarg named exception to check whether there was an exception or not. Examples: saving log files, loading the best model found during training.
show_doc(Callback.get_state)
get_state(minimal:bool=True)
Return the inner state of the Callback, minimal or not.
The following functions provide different annealing schedules. You probably won't need to call them directly, but would instead use them as part of a callback. Here's what each one looks like:
annealings = "NO LINEAR COS EXP POLY".split()
fns = [annealing_no, annealing_linear, annealing_cos, annealing_exp, annealing_poly(0.8)]
for fn, t in zip(fns, annealings):
    plt.plot(np.arange(0, 100), [fn(2, 1e-2, o) for o in np.linspace(0.01, 1, 100)], label=t)
plt.legend();
show_doc(annealing_cos)
annealing_cos(start:Number, end:Number, pct:float) → Number
Cosine anneal from start to end as pct goes from 0.0 to 1.0.
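A schedule of this shape can be written as a small sketch of the formula (the function name cosine_anneal is ours, to avoid shadowing the library function; numpy is already available via the imports above):
import numpy as np

def cosine_anneal(start, end, pct):
    "Sketch: equals `start` at pct=0 and `end` at pct=1, following half a cosine."
    return end + (start - end) / 2 * (1 + np.cos(np.pi * pct))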
show_doc(annealing_exp)
annealing_exp(start:Number, end:Number, pct:float) → Number
Exponentially anneal from start to end as pct goes from 0.0 to 1.0.
show_doc(annealing_linear)
annealing_linear(start:Number, end:Number, pct:float) → Number
Linearly anneal from start to end as pct goes from 0.0 to 1.0.
show_doc(annealing_no)
annealing_no(start:Number, end:Number, pct:float) → Number
No annealing, always return start.
show_doc(annealing_poly)
annealing_poly(degree:Number) → Number
Anneal polynomially from start to end as pct goes from 0.0 to 1.0.
show_doc(CallbackHandler)
class CallbackHandler(callbacks:Collection[Callback]=None, metrics:Collection[Callback]=None, beta:float=0.98)
Manage all of the registered callbacks and metrics, smoothing loss by momentum beta.
You probably won't need to use this class yourself. It's used by fastai to combine all the callbacks together and call any relevant callback functions for each training stage. The methods below simply call the equivalent method in each callback function in self.callbacks.
show_doc(CallbackHandler.on_backward_begin)
on_backward_begin(loss:Tensor)
Handle gradient calculation on loss.
show_doc(CallbackHandler.on_backward_end)
on_backward_end()
Handle end of gradient calculation.
show_doc(CallbackHandler.on_batch_begin)
on_batch_begin(xb:Tensor, yb:Tensor, train:bool=True)
Handle new batch xb, yb in train or validation.
show_doc(CallbackHandler.on_batch_end)
on_batch_end(loss:Tensor)
Handle end of processing one batch with loss.
show_doc(CallbackHandler.on_epoch_begin)
on_epoch_begin()
Handle new epoch.
show_doc(CallbackHandler.on_epoch_end)
on_epoch_end(val_loss:Tensor) → bool
Epoch is done, process val_loss.
show_doc(CallbackHandler.on_loss_begin)
on_loss_begin(out:Tensor)
Handle start of loss calculation with model output out.
show_doc(CallbackHandler.on_step_end)
on_step_end()
Handle end of optimization step.
show_doc(CallbackHandler.on_train_begin)
on_train_begin(epochs:int, pbar:PBar, metrics:MetricFuncList)
About to start learning.
show_doc(CallbackHandler.on_train_end)
on_train_end(exception:Union[bool,Exception])
Handle end of training; exception is an Exception instance, or False if no exception occurred during training.
show_doc(CallbackHandler.set_dl)
set_dl(dl:DataLoader)
Set the current dl used.
show_doc(OptimWrapper)
class OptimWrapper(opt:Optimizer, wd:Floats=0.0, true_wd:bool=False, bn_wd:bool=True)
Basic wrapper around opt to simplify hyper-parameter changes.
This is a convenience class that provides a consistent API for getting and setting optimizer hyper-parameters. For instance, for optim.Adam the momentum parameter is actually betas[0], whereas for optim.SGD it's simply momentum. As another example, the details of handling weight decay depend on whether you are using true_wd or the traditional L2 regularization approach.
This class also handles setting different WD and LR for each layer group, for discriminative layer training.
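As a sketch of typical usage (assuming an existing Learner named learn whose optimizer has already been created, e.g. after at least one call to fit):
opt = learn.opt                  # an OptimWrapper around the underlying torch optimizer
print(opt.lr, opt.mom, opt.wd)   # uniform names, whatever the wrapped optimizer is
opt.lr = 1e-3                    # broadcast to the layer groups
opt.mom = 0.9                    # maps to betas[0] for Adam, momentum for SGD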
show_doc(OptimWrapper.clear)
clear()
Reset the state of the inner optimizer.
show_doc(OptimWrapper.create)
create(opt_func:Union[type,Callable], lr:Union[float,Tuple,List[T]], layer_groups:ModuleList, wd:Floats=0.0, true_wd:bool=False, bn_wd:bool=True) → Optimizer
Create an optim.Optimizer from opt_func with lr. Set lr on layer_groups.
show_doc(OptimWrapper.new)
new(layer_groups:ModuleList)
Create a new OptimWrapper from self with another layer_groups but the same hyper-parameters.
show_doc(OptimWrapper.read_defaults)
read_defaults()
Read the values inside the optimizer for the hyper-parameters.
show_doc(OptimWrapper.read_val)
read_val(key:str) → Union[List[float], Tuple[List[float],List[float]]]
Read a hyper-parameter key in the optimizer dictionary.
show_doc(OptimWrapper.set_val)
set_val(key:str, val:Any, bn_groups:bool=True) → Any
Set val inside the optimizer dictionary at key.
show_doc(OptimWrapper.step)
step()
Set weight decay and step optimizer.
show_doc(OptimWrapper.zero_grad)
zero_grad()
Clear optimizer gradients.
show_doc(SmoothenValue)
class SmoothenValue(beta:float)
Create a smooth moving average for a value (loss, etc.) using beta.
Used for smoothing loss in Recorder.
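The idea is an exponential moving average with bias correction (as used in Adam); a minimal standalone sketch of the computation:
beta, mov_avg = 0.98, 0.
for n, val in enumerate([2.0, 1.5, 1.2, 1.1], start=1):   # e.g. a stream of batch losses
    mov_avg = beta * mov_avg + (1 - beta) * val
    smooth = mov_avg / (1 - beta ** n)                     # bias-corrected smoothed value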
show_doc(SmoothenValue.add_value)
add_value(val:float)
Add val to calculate updated smoothed value.
show_doc(Stepper)
class Stepper(vals:StartOptEnd, n_iter:int, func:Optional[AnnealFunc]=None)
Used to "step" from start to end (vals) over n_iter iterations on a schedule defined by func.
Used for creating annealing schedules, mainly for OneCycleScheduler.
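For instance, a sketch of stepping a value from 1e-2 down to 1e-4 over 100 iterations with a cosine schedule (assuming vals is given as a (start, end) tuple):
sched = Stepper((1e-2, 1e-4), 100, annealing_cos)
lrs = [sched.step() for _ in range(100)]   # lrs[0] is near 1e-2, lrs[-1] near 1e-4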
show_doc(Stepper.step)
step() → Number
Return next value along annealed schedule.
show_doc(AverageMetric)
class AverageMetric(func) :: Callback
Wrap a func in a callback for metrics computation.
See the documentation on metrics for more information.
You don't call these yourself - they're called by fastai's Callback system automatically to enable the class's functionality.
show_doc(AverageMetric.on_epoch_begin)
on_epoch_begin(**kwargs)
Set the inner value to 0.
show_doc(AverageMetric.on_batch_end)
on_batch_end(last_output, last_target, **kwargs)
Update metric computation with last_output and last_target.
show_doc(AverageMetric.on_epoch_end)
on_epoch_end(last_metrics, **kwargs)
Set the final result in last_metrics.