from fastai.basic_train import *
from fastai.gen_doc.nbdoc import *
from fastai.vision import *
from fastai.distributed import *
basic_train wraps together the data (in a DataBunch object) with a PyTorch model to define a Learner object. This is where the basic training loop used by the fit method is defined. The Learner object is the entry point of most of the Callback objects that customize this training loop in different ways. Some of the most commonly used customizations are available through the train module, notably:
- Learner.lr_find will launch an LR range test that will help you select a good learning rate.
- Learner.fit_one_cycle will launch a training using the 1cycle policy to help you train your model faster.
- Learner.to_fp16 will convert your model to half precision and help you launch a training in mixed precision.

show_doc(Learner, title_level=2)
class Learner(data:DataBunch, model:Module, opt_func:Callable='Adam', loss_func:Callable=None, metrics:Collection[Callable]=None, true_wd:bool=True, bn_wd:bool=True, wd:Floats=0.01, train_bn:bool=True, path:str=None, model_dir:str='models', callback_fns:Collection[Callable]=None, callbacks:Collection[Callback]=<factory>, layer_groups:ModuleList=None)
Trainer for model using data to minimize loss_func with optimizer opt_func.
The main purpose of Learner is to train model using Learner.fit. After every epoch, all metrics will be printed and also made available to callbacks.
The default weight decay is wd, which will be handled using the method from Fixing Weight Decay Regularization in Adam if true_wd is set (otherwise it's L2 regularization). If bn_wd is False, then weight decay will be removed from batchnorm layers, as recommended in Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour. If train_bn is True, the learnable parameters of batchnorm layers are trained even for frozen layer groups.
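To see why the choice of true_wd matters with Adam, here is a minimal scalar sketch (not fastai's optimizer wrapper) of a single Adam step from a fresh optimizer state. The helper adam_step is hypothetical, and decoupled decay is written as shrinking the weight by (1 - lr*wd) before the adaptive update, which is one common formulation.

```python
import math

def adam_step(w, grad, lr, wd, true_wd, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical helper: one Adam step from a zero state for a single scalar weight."""
    if true_wd:
        # Decoupled weight decay: shrink the weight directly, outside the adaptive update.
        w = w * (1 - lr * wd)
    else:
        # L2 regularization: fold wd * w into the gradient before Adam normalizes it.
        grad = grad + wd * w
    m = (1 - beta1) * grad           # first moment (state started at 0)
    v = (1 - beta2) * grad ** 2      # second moment (state started at 0)
    m_hat = m / (1 - beta1)          # bias correction for step 1
    v_hat = v / (1 - beta2)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)

print(adam_step(2.0, 0.5, lr=0.1, wd=0.01, true_wd=False))  # about 1.9
print(adam_step(2.0, 0.5, lr=0.1, wd=0.01, true_wd=True))   # about 1.898
```

With L2 regularization the decay term is divided out by Adam's normalization on this step, so it has essentially no effect; the decoupled version still shrinks the weight by lr*wd*w.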
To use discriminative layer training, pass a list of nn.Module as layer_groups; each nn.Module will be used to customize the optimization of the corresponding layer group.
If path is provided, all the model files created will be saved in path/model_dir; if not, then they will be saved in data.path/model_dir.
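The resolution rule can be sketched with pathlib. The helper model_file is hypothetical; the .pth extension matches the path printed by learn.save further down this page.

```python
from pathlib import Path

def model_file(name, data_path, path=None, model_dir='models'):
    """Hypothetical helper: where a saved model ends up, following the rule above."""
    base = Path(path) if path is not None else Path(data_path)  # fall back to data.path
    return base / model_dir / f'{name}.pth'

print(model_file('trained_model', '/home/jupyter/.fastai/data/mnist_sample'))
print(model_file('trained_model', '/home/jupyter/.fastai/data/mnist_sample', path='/tmp/run'))
```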
You can pass a list of callbacks that you have already created, or (more commonly) simply pass a list of callback functions to callback_fns; each function will be called (passing self) on object initialization, with the results stored as callback objects. For a walk-through, see the training overview page. You may also want to use an application-specific model. For example, if you are dealing with a vision dataset, here MNIST, you might want to use the create_cnn method:
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = create_cnn(data, models.resnet18, metrics=accuracy)
show_doc(Learner.lr_find)
Runs the learning rate finder defined in LRFinder, as discussed in Cyclical Learning Rates for Training Neural Networks.
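An LR range test trains for a number of mini-batches while increasing the learning rate exponentially, recording the loss at each step. Here is a sketch of such a schedule; the function name is hypothetical, and the defaults (start 1e-7, end 10, 100 iterations) are assumptions chosen for illustration rather than values quoted from this page.

```python
def lr_range_test_schedule(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Hypothetical helper: learning rates growing exponentially from start_lr to end_lr."""
    # Each step multiplies the LR by the same factor, so the points are evenly
    # spaced on a log scale, which is how the LR finder plot is usually drawn.
    return [start_lr * (end_lr / start_lr) ** (i / (num_it - 1)) for i in range(num_it)]

lrs = lr_range_test_schedule()
print(lrs[0], lrs[-1], lrs[1] / lrs[0])
```

In the actual test, one batch is trained at each of these rates and the run usually stops early once the loss diverges.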
learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot()
Min numerical gradient: 1.32E-02
show_doc(Learner.fit)
Uses discriminative layer training if multiple learning rates or weight decay values are passed. To control training behaviour, use the callback system or one or more of the pre-defined callbacks.
learn.fit(1)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 0.129607 | 0.082084 | 0.973013 |
show_doc(Learner.fit_one_cycle)
fit_one_cycle(learn:Learner, cyc_len:int, max_lr:Union[float, Collection[float], slice]=slice(None, 0.003, None), moms:Point=(0.95, 0.85), div_factor:float=25.0, pct_start:float=0.3, wd:float=None, callbacks:Optional[Collection[Callback]]=None, tot_epochs:int=None, start_epoch:int=1)
Fit a model following the 1cycle policy.
Use cycle length cyc_len, a per-cycle maximal learning rate max_lr, momentum moms, division factor div_factor, weight decay wd, and optional callbacks callbacks. Uses the OneCycleScheduler callback. Please refer to What is 1-cycle for a conceptual background of the 1cycle training policy and more technical details on what the method's arguments do.
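As a rough illustration of how max_lr, moms, div_factor and pct_start interact, here is a two-phase sketch: during the first pct_start of the iterations the learning rate rises from max_lr/div_factor to max_lr while momentum falls from moms[0] to moms[1], then both are annealed back. This is our reading of the policy, not OneCycleScheduler's code; cosine interpolation and the final learning rate of max_lr/(div_factor*1e4) are assumptions.

```python
import math

def annealing_cos(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle(it, total_its, max_lr=3e-3, moms=(0.95, 0.85), div_factor=25.0, pct_start=0.3):
    """Sketch: (lr, momentum) at iteration `it` of a two-phase 1cycle schedule."""
    low_lr = max_lr / div_factor
    final_lr = low_lr / 1e4                  # assumption: how low the LR ends
    warmup_its = int(total_its * pct_start)
    if it < warmup_its:                      # phase 1: LR up, momentum down
        pct = it / warmup_its
        return annealing_cos(low_lr, max_lr, pct), annealing_cos(moms[0], moms[1], pct)
    pct = (it - warmup_its) / (total_its - warmup_its)   # phase 2: LR down, momentum up
    return annealing_cos(max_lr, final_lr, pct), annealing_cos(moms[1], moms[0], pct)

for it in (0, 15, 30, 65, 99):
    print(it, one_cycle(it, total_its=100))
```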
learn.fit_one_cycle(1)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 0.088884 | 0.066379 | 0.978410 |
show_doc(Learner.predict)
predict returns a single prediction from the trained learner for one specific item you are interested in.
learn.data.train_ds[0]
(Image (3, 28, 28), Category 3)
Each element of the dataset is a tuple, where the first element is the data itself, while the second element is the target label. So to get the data, we need to index one more time.
data = learn.data.train_ds[0][0]
data
pred = learn.predict(data)
pred
(Category 3, tensor(0), tensor([9.9979e-01, 2.0649e-04]))
The first two elements of the tuple are, respectively, the predicted class and its label. The label is an integer index that represents each class internally, since a class name is a string and cannot be used in computation. To check what each label corresponds to, run:
learn.data.classes
['3', '7']
So label 0 corresponds to class '3', while label 1 corresponds to class '7'.
probs = pred[2]
The last element in the tuple contains the predicted probabilities. For a categorization dataset, the number of probabilities returned is the same as the number of classes; probs[i] is the probability that the item belongs to learn.data.classes[i].
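To make the link between probabilities, labels and classes concrete, here is a plain-Python sketch (it does not call fastai) that recovers the predicted class from the probabilities printed above by taking the index of the largest one, which is what an argmax over the tensor does.

```python
probs = [9.9979e-01, 2.0649e-04]   # probabilities printed by learn.predict above
classes = ['3', '7']               # learn.data.classes

label = max(range(len(probs)), key=lambda i: probs[i])  # index of the largest probability
print(label, classes[label])       # 0 3
```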
learn.data.valid_ds[0][0]
You can always check for yourself whether the predicted probabilities make sense.
show_doc(Learner.get_preds)
get_preds(ds_type:DatasetType=<DatasetType.Valid: 2>, with_loss:bool=False, n_batch:Optional[int]=None, pbar:Union[MasterBar, ProgressBar, NoneType]=None) → List[Tensor]
Return predictions and targets on ds_type dataset.
It will run inference using the learner on all the data in the ds_type dataset and return the predictions, processing the data one batch at a time with the default batch size; if n_batch is specified, it stops after that many batches. If with_loss is True, it will also return the loss on each prediction.
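The effect of n_batch can be sketched in plain Python. batched_predictions is a hypothetical stand-in for the batching loop, not fastai code; it uses the same dataset size and batch size as the MNIST sample shown on this page.

```python
def batched_predictions(items, batch_size, predict_fn, n_batch=None):
    """Hypothetical stand-in for get_preds' loop: stop after n_batch batches if given."""
    preds = []
    for b, start in enumerate(range(0, len(items), batch_size)):
        if n_batch is not None and b >= n_batch:
            break
        batch = items[start:start + batch_size]
        preds.extend(predict_fn(x) for x in batch)
    return preds

items = list(range(2038))   # same size as the MNIST_SAMPLE validation set
print(len(batched_predictions(items, 64, lambda x: x % 2)))             # 2038
print(len(batched_predictions(items, 64, lambda x: x % 2, n_batch=3)))  # 192
```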
Here is how you check the default batch size.
learn.data.batch_size
64
preds = learn.get_preds()
preds
[tensor([[9.9366e-01, 6.3430e-03], [9.9828e-01, 1.7193e-03], [9.9993e-01, 7.1130e-05], ..., [1.5793e-04, 9.9984e-01], [9.0569e-03, 9.9094e-01], [9.8014e-01, 1.9864e-02]]), tensor([0, 0, 0, ..., 1, 1, 1])]
The first element of the tuple is a tensor that contains all the predictions.
preds[0]
tensor([[9.9366e-01, 6.3430e-03], [9.9828e-01, 1.7193e-03], [9.9993e-01, 7.1130e-05], ..., [1.5793e-04, 9.9984e-01], [9.0569e-03, 9.9094e-01], [9.8014e-01, 1.9864e-02]])
The second element of the tuple is a tensor that contains all the target labels.
preds[1]
tensor([0, 0, 0, ..., 1, 1, 1])
preds[1][0]
tensor(0)
len(learn.data.valid_ds)
2038
len(preds[0]), len(preds[1])
(2038, 2038)
To get predictions on the entire training dataset, simply set the ds_type argument accordingly.
learn.get_preds(ds_type=DatasetType.Train)
[tensor([[9.9973e-01, 2.6554e-04], [9.9962e-01, 3.8422e-04], [9.9988e-01, 1.1570e-04], ..., [9.9922e-01, 7.8436e-04], [4.4838e-04, 9.9955e-01], [1.3715e-04, 9.9986e-01]]), tensor([0, 0, 0, ..., 0, 1, 1])]
To also get the prediction loss along with the predictions and the targets, set with_loss=True in the arguments.
learn.get_preds(with_loss=True)
[tensor([[9.9366e-01, 6.3430e-03], [9.9828e-01, 1.7193e-03], [9.9993e-01, 7.1130e-05], ..., [1.5793e-04, 9.9984e-01], [9.0569e-03, 9.9094e-01], [9.8014e-01, 1.9864e-02]]), tensor([0, 0, 0, ..., 1, 1, 1]), tensor([6.3632e-03, 1.7209e-03, 7.1049e-05, ..., 1.5783e-04, 9.0983e-03, 3.9189e+00])]
Note that the third tensor in the output tuple contains the losses.
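Assuming the loss function is cross-entropy (the usual choice for single-label classification, and consistent with the numbers above), each per-item loss is minus the log of the probability given to the true class. This plain-Python check uses the first and last rows printed above.

```python
import math

def cross_entropy(probs, target):
    """Per-item cross-entropy, given probabilities that are already softmaxed."""
    return -math.log(probs[target])

first = cross_entropy([9.9366e-01, 6.3430e-03], target=0)
last = cross_entropy([9.8014e-01, 1.9864e-02], target=1)
print(first, last)   # close to the printed 6.3632e-03 and 3.9189, up to rounding
```

The confidently wrong last item (true class 1, but 98% on class 0) dominates, which is why sorting these losses is a quick way to find the worst predictions.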
show_doc(Learner.validate)
validate(dl=None, callbacks=None, metrics=None)
Validate on dl with potential callbacks and metrics.
Return the calculated loss and the metrics of the current model on the given data loader dl. The default data loader dl is the validation dataloader.
You can check the default metrics of the learner using:
str(learn.metrics)
'[<function accuracy at 0x7f1effc86d08>]'
learn.validate()
[0.06637867, tensor(0.9784)]
learn.validate(learn.data.valid_dl)
[0.06637867, tensor(0.9784)]
learn.validate(learn.data.train_dl)
[0.039573476, tensor(0.9860)]
show_doc(Learner.show_results)
show_results(ds_type=<DatasetType.Valid: 2>, rows:int=5, **kwargs)
Show rows result of predictions on ds_type dataset.
Note that the number at the top is the ground truth (the target label), the one in the middle is the prediction, and the image at the bottom is the input data itself.
learn.show_results()
learn.show_results(ds_type=DatasetType.Train)
show_doc(Learner.pred_batch)
pred_batch(ds_type:DatasetType=<DatasetType.Valid: 2>, batch:Tuple=None, reconstruct:bool=False) → List[Tensor]
Return output of the model on one batch from ds_type dataset.
Note that the number of predictions returned equals the batch size.
learn.data.batch_size
64
preds = learn.pred_batch()
len(preds)
64
Since there are too many predictions to display, we only look at the first ten.
preds[:10]
tensor([[9.9366e-01, 6.3430e-03], [9.9828e-01, 1.7193e-03], [9.9993e-01, 7.1130e-05], [1.0000e+00, 5.2653e-07], [9.9839e-01, 1.6092e-03], [1.0000e+00, 9.6659e-07], [9.5156e-01, 4.8442e-02], [9.9854e-01, 1.4628e-03], [9.9937e-01, 6.2854e-04], [8.3490e-01, 1.6510e-01]])
item = learn.data.train_ds[0][0]
item
batch = learn.data.one_item(item)
batch
(tensor([[[[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], [[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], [[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]]]], device='cuda:0'), tensor([0], device='cuda:0'))
learn.pred_batch(batch=batch)
tensor([[9.9979e-01, 2.0649e-04]])
show_doc(Learner.interpret, full_name='interpret')
interpret(learn:Learner, ds_type:DatasetType=<DatasetType.Valid: 2>, tta=False)
Create a ClassificationInterpretation object from learner on ds_type with tta.
jekyll_note('This function only works in the vision application.')
For more details, refer to ClassificationInterpretation.
show_doc(Learner.summary)
show_doc(Learner.TTA, full_name = 'TTA')
TTA(learn:Learner, beta:float=0.4, scale:float=1.35, ds_type:DatasetType=<DatasetType.Valid: 2>, with_loss:bool=False) → Tensors
Applies TTA to predict on ds_type dataset.
Applies Test Time Augmentation to learn on the dataset ds_type. We take the average of our regular predictions (with a weight beta) with the average of predictions obtained through augmented versions of the dataset, using the training-set transforms (with a weight 1-beta). The transforms decided for the training set are applied with a few changes: scale controls the scale for zoom (which isn't random); the cropping isn't random, but we make sure to get the four corners of the image; flipping isn't random either, but is applied once to each of those corner images (so that makes 8 augmented versions in total).
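The weighting described above can be sketched for a single item in plain Python. tta_combine is a hypothetical helper, and the augmented outputs below are made-up numbers; the real method works on tensors of shape (n_items, n_classes).

```python
def tta_combine(regular, augmented, beta=0.4):
    """Hypothetical helper: beta * regular + (1 - beta) * mean of the augmented predictions."""
    n = len(augmented)
    avg = [sum(preds[i] for preds in augmented) / n for i in range(len(regular))]
    return [beta * r + (1 - beta) * a for r, a in zip(regular, avg)]

# The 8 deterministic variants described above: 4 corners x (no flip, flip).
variants = [(corner, flip) for flip in (False, True)
            for corner in ('top-left', 'top-right', 'bottom-left', 'bottom-right')]
print(len(variants))  # 8

regular = [0.9, 0.1]                               # made-up regular prediction
augmented = [[0.8, 0.2]] * 4 + [[0.6, 0.4]] * 4    # made-up outputs for the 8 variants
print(tta_combine(regular, augmented))             # approximately [0.78, 0.22]
```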
show_doc(Learner.clip_grad)
show_doc(Learner.to_fp16)
Uses the MixedPrecision callback to train in mixed precision (i.e. forward and backward passes using fp16, with weight updates using fp32), using all NVIDIA recommendations for ensuring speed and accuracy.
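Two of the standard reasons behind those recommendations (keeping an fp32 master copy of the weights, and scaling the loss before the backward pass) can be illustrated with Python's struct module, which can round a float to IEEE half precision. This is a generic illustration, not the MixedPrecision callback's code; the loss scale of 1024 is an arbitrary example value.

```python
import struct

def to_fp16(x):
    """Round a Python float to IEEE half precision and back."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

def to_fp32(x):
    """Round a Python float to IEEE single precision and back."""
    return struct.unpack('<f', struct.pack('<f', x))[0]

# 1) A small weight update is lost in fp16 but kept in an fp32 master copy.
print(to_fp16(1.0 + 1e-4))   # 1.0
print(to_fp32(1.0 + 1e-4))   # slightly above 1.0

# 2) A tiny gradient underflows to zero in fp16; multiplying the loss (and so the
#    gradients) by a constant keeps it representable, and we divide it back out in fp32.
grad, loss_scale = 1e-8, 1024.0
print(to_fp16(grad))                             # 0.0
print(to_fp16(grad * loss_scale) / loss_scale)   # about 1e-8 again
```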
show_doc(Learner.to_fp32)
show_doc(Learner.distributed, full_name='distributed')
When fitting a model you can pass a list of learning rates (and/or weight decay amounts), which will apply a different rate to each layer group (i.e. the parameters of each module in self.layer_groups). See the Universal Language Model Fine-tuning for Text Classification paper for details and experimental results in NLP (we also frequently use them successfully in computer vision, but have not published a paper on this topic yet). When working with a Learner on which you've called split, you can set hyperparameters in four ways:
1. param = [val1, val2 ..., valn] (n = number of layer groups)
2. param = val
3. param = slice(start,end)
4. param = slice(end)
If we choose to set it in way 1, we must specify a number of values exactly equal to the number of layer groups. If we choose to set it in way 2, the chosen value will be repeated for all layer groups. See Learner.lr_range for an explanation of the slice syntax used in ways 3 and 4.
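Ways 1 and 2 amount to the following expansion into one value per layer group. expand_per_group is a hypothetical helper written for illustration; slices (ways 3 and 4) are not handled here.

```python
def expand_per_group(param, n_groups):
    """Hypothetical helper for ways 1 and 2: return one value per layer group."""
    if isinstance(param, (list, tuple)):
        # Way 1: the caller must supply exactly one value per group.
        if len(param) != n_groups:
            raise ValueError(f'expected {n_groups} values, got {len(param)}')
        return list(param)
    # Way 2: a single value is repeated for every group.
    return [param] * n_groups

print(expand_per_group([1e-4, 1e-3, 1e-2], 3))  # [0.0001, 0.001, 0.01]
print(expand_per_group(1e-3, 3))                 # [0.001, 0.001, 0.001]
```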
Here's an example of how to use discriminative learning rates (note that you don't actually need to manually call Learner.split in this case, since fastai uses this exact function as the default split for resnet18; this is just to show how to customize it):
# creates 3 layer groups
learn.split(lambda m: (m[0][6], m[1]))
# only randomly initialized head now trainable
learn.freeze()
learn.fit_one_cycle(1)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 0.067769 | 0.060910 | 0.979392 |
# all layers now trainable
learn.unfreeze()
# optionally, separate LR and WD for each group
learn.fit_one_cycle(1, max_lr=(1e-4, 1e-3, 1e-2), wd=(1e-4,1e-4,1e-1))
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 0.022366 | 0.006872 | 0.998037 |
show_doc(Learner.lr_range)
lr_range(lr:Union[float, slice]) → ndarray
Build differential learning rates from lr.
Rather than manually setting an LR for every group, it's often easier to use Learner.lr_range. This is a convenience method that returns one learning rate for each layer group. If you pass slice(start,end) then the first group's learning rate is start, the last is end, and the remaining are evenly geometrically spaced.
If you pass just slice(end) then the last group's learning rate is end, and all the other groups are end/10. For instance (for our learner that has 3 layer groups):
learn.lr_range(slice(1e-5,1e-3)), learn.lr_range(slice(1e-3))
(array([1.e-05, 1.e-04, 1.e-03]), array([0.0001, 0.0001, 0.001 ]))
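The two slice rules can be reproduced in plain Python. This sketch takes the number of layer groups as an explicit argument, returns a list rather than an ndarray, and assumes at least two groups when a slice start is given; it reproduces the two arrays printed above.

```python
def lr_range(lr, n_groups):
    """Sketch of the slice rules described above for n_groups layer groups."""
    if not isinstance(lr, slice):
        return lr                                   # plain values pass through unchanged
    if lr.start:
        # slice(start, end): geometric spacing, so consecutive groups differ by a constant factor.
        step = (lr.stop / lr.start) ** (1 / (n_groups - 1))
        return [lr.start * step ** i for i in range(n_groups)]
    # slice(end): every group but the last gets end/10.
    return [lr.stop / 10] * (n_groups - 1) + [lr.stop]

print(lr_range(slice(1e-5, 1e-3), 3))
print(lr_range(slice(1e-3), 3))
```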
show_doc(Learner.unfreeze)
Sets every layer group to trainable (i.e. requires_grad=True).
show_doc(Learner.freeze)
Sets every layer group except the last to untrainable (i.e. requires_grad=False).
show_doc(Learner.freeze_to)
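Here is a plain-Python sketch of how freeze_to, freeze and unfreeze relate, including the train_bn behaviour described for the Learner constructor. Layers are modeled as hypothetical dicts instead of nn.Module objects, and treating freeze as freeze_to(-1) and unfreeze as freeze_to(0) is our reading of the descriptions above.

```python
def freeze_to(groups, n, train_bn=True):
    """Sketch: groups[:n] become untrainable (batchnorm kept if train_bn), groups[n:] trainable."""
    for group in groups[:n]:
        for layer in group:
            if not (train_bn and layer['is_bn']):
                layer['requires_grad'] = False
    for group in groups[n:]:
        for layer in group:
            layer['requires_grad'] = True

def make_groups():
    """Three hypothetical layer groups; every layer starts out trainable."""
    layer = lambda name, is_bn=False: {'name': name, 'is_bn': is_bn, 'requires_grad': True}
    return [[layer('conv1'), layer('bn1', is_bn=True)], [layer('conv2')], [layer('head')]]

groups = make_groups()
freeze_to(groups, -1)   # like freeze(): all but the last group frozen
print([(l['name'], l['requires_grad']) for g in groups for l in g])
# [('conv1', False), ('bn1', True), ('conv2', False), ('head', True)]
freeze_to(groups, 0)    # like unfreeze(): every group trainable
```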
show_doc(Learner.split)
A convenience method that sets layer_groups based on the result of split_model. If split_on is a function, it calls that function and passes the result to split_model (see above for example).
Simply call Learner.save and Learner.load to save and load models. Only the parameters are saved, not the actual architecture (so you'll need to create your model in the same way before loading weights back in). Models are saved to the path/model_dir directory.
show_doc(Learner.save)
save(name:PathOrStr, return_path:bool=False, with_opt:bool=True)
Save model and optimizer state (if with_opt) with name to self.model_dir.
learn.save("trained_model")
learn.save("trained_model", return_path=True)
PosixPath('/home/jupyter/.fastai/data/mnist_sample/models/trained_model.pth')
show_doc(Learner.load)
learn = learn.load("trained_model")
When you are ready to put your model in production, export the minimal state of your Learner with Learner.export:
show_doc(Learner.export)
learn.export()
learn.export('trained_model.pkl')
path = learn.path
path
PosixPath('/home/jupyter/.fastai/data/mnist_sample')
show_doc(load_learner)
learn = load_learner(path)
learn = load_learner(path, fname='trained_model.pkl')
WARNING: If you used any customized classes when creating your learner, you must define these classes before executing load_learner.
You can find more information and multiple examples in this tutorial.
show_doc(Learner.init)
init(init)
Initializes all weights (except batchnorm) using function init, which will often be from PyTorch's nn.init module.
show_doc(Learner.mixup)
mixup(learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True) → Learner
Add mixup https://arxiv.org/abs/1710.09412 to learn.
Uses MixUpCallback.
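The core idea of mixup from the linked paper is to train on convex combinations of pairs of examples, with the mixing coefficient drawn from a Beta(alpha, alpha) distribution. Here is a plain-Python sketch for one pair; mixup_pair is hypothetical, the max(lam, 1 - lam) step is an optional choice that keeps each mixed sample dominated by its first example, and the sketch does not model the stack_x/stack_y options.

```python
import random

def mixup_pair(x1, y1, x2, y2, alpha=0.4, rng=random):
    """Hypothetical helper: blend two examples (float lists, one-hot targets)."""
    lam = rng.betavariate(alpha, alpha)
    lam = max(lam, 1 - lam)   # optional: keep the first example dominant
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x, y, lam

rng = random.Random(0)
x, y, lam = mixup_pair([1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], rng=rng)
print(lam, x, y)
```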
show_doc(Learner.backward)
backward(item)
Pass item through the model and compute the gradient. Useful if backward_hooks are attached.
show_doc(Learner.create_opt)
create_opt(lr:Floats, wd:Floats=0.0)
Create optimizer with lr learning rate and wd weight decay.
You generally won't need to call this yourself; it's used to create the optim optimizer before fitting the model.
show_doc(Learner.dl)
dl(ds_type:DatasetType=<DatasetType.Valid: 2>)
Return DataLoader for DatasetType ds_type.
learn.dl()
DeviceDataLoader(dl=<torch.utils.data.dataloader.DataLoader object at 0x7f1efe504780>, device=device(type='cuda'), tfms=[], collate_fn=<function data_collate at 0x7f1f16f140d0>)
learn.dl(DatasetType.Train)
DeviceDataLoader(dl=<torch.utils.data.dataloader.DataLoader object at 0x7f1f696aa4a8>, device=device(type='cuda'), tfms=[], collate_fn=<function data_collate at 0x7f1f16f140d0>)
show_doc(Recorder, title_level=2)
class Recorder(learn:Learner) :: LearnerCallback
A LearnerCallback that records epoch, loss, opt and metric data during training.
A Learner creates a Recorder object automatically (you do not need to explicitly pass it to callback_fns) because other callbacks rely on it being available. It stores the smoothed loss, hyperparameter values, and metrics for each batch, and provides plotting methods for each. Note that Learner automatically sets an attribute with the snake-cased name of each callback, so you can access this through Learner.recorder, as shown below.
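A smoothed loss is typically an exponentially weighted moving average with a bias correction, so that early values are not pulled toward the initial zero. Here is a sketch of that technique; the class name and beta=0.98 are assumptions for illustration, not values quoted from this page.

```python
class SmoothedValue:
    """Sketch: exponentially weighted moving average with bias correction."""

    def __init__(self, beta=0.98):
        self.beta, self.n, self.mov_avg = beta, 0, 0.0
        self.smooth = None

    def add_value(self, val):
        self.n += 1
        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
        # Dividing by (1 - beta**n) removes the bias from starting the average at 0.
        self.smooth = self.mov_avg / (1 - self.beta ** self.n)

sv = SmoothedValue()
for loss in [0.7, 0.5, 0.6, 0.4]:
    sv.add_value(loss)
    print(round(sv.smooth, 4))
```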
show_doc(Recorder.plot)
plot(skip_start:int=10, skip_end:int=5)
Plot learning rate and losses, trimmed between skip_start and skip_end. Optionally plot and return min gradient.
This is mainly used with the learning rate finder, since it shows a scatterplot of loss vs learning rate.
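The "Min numerical gradient" value printed by the LR finder is the learning rate where the recorded loss falls fastest. Here is a plain-Python sketch of that idea using finite differences over the recorded points (one-sided at the ends, central in between); steepest_descent_lr is hypothetical, and the loss curve below is synthetic.

```python
def steepest_descent_lr(lrs, losses, skip_start=10, skip_end=5):
    """Hypothetical helper: LR at which the loss decreases fastest between recorded points."""
    lrs = lrs[skip_start:len(lrs) - skip_end]
    losses = losses[skip_start:len(losses) - skip_end]
    n = len(losses)
    grads = [losses[1] - losses[0]]
    grads += [(losses[i + 1] - losses[i - 1]) / 2 for i in range(1, n - 1)]
    grads.append(losses[-1] - losses[-2])
    best = min(range(n), key=lambda i: grads[i])   # most negative slope
    return lrs[best]

# Synthetic run: flat loss, then a steady drop, then divergence.
lrs = [10 ** (-7 + 8 * i / 99) for i in range(100)]
losses = [1.0 if i < 40 else (1.0 - 0.04 * (i - 40) if i < 60 else 0.2 + 0.1 * (i - 60))
          for i in range(100)]
print(f'{steepest_descent_lr(lrs, losses):.2E}')
```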
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = create_cnn(data, models.resnet18, metrics=accuracy)
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
Min numerical gradient: 7.59E-03
show_doc(Recorder.plot_losses)
Note that validation losses are only calculated once per epoch, whereas training losses are calculated after every batch.
learn.fit_one_cycle(5)
learn.recorder.plot_losses()
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 0.247247 | 0.141247 | 0.954367 |
2 | 0.109672 | 0.078876 | 0.972522 |
3 | 0.065391 | 0.054635 | 0.983808 |
4 | 0.044042 | 0.049592 | 0.981845 |
5 | 0.041287 | 0.049224 | 0.984298 |
show_doc(Recorder.plot_lr)
learn.recorder.plot_lr()
learn.recorder.plot_lr(show_moms=True)
show_doc(Recorder.plot_metrics)
Note that metrics are only collected at the end of each epoch, so you'll need to train at least two epochs to have anything to show here.
learn.recorder.plot_metrics()
show_doc(Recorder.on_backward_begin)
on_backward_begin(smooth_loss:Tensor, **kwargs:Any)
Record the loss before any other callback has a chance to modify it.
show_doc(Recorder.on_batch_begin)
on_batch_begin(train, **kwargs:Any)
Record learning rate and momentum at beginning of batch.
show_doc(Recorder.on_epoch_end)
on_epoch_end(epoch:int, num_batch:int, smooth_loss:Tensor, last_metrics=typing.Collection[typing.Union[torch.Tensor, numbers.Number]], **kwargs:Any) → bool
Save epoch info: num_batch, smooth_loss, metrics.
show_doc(Recorder.on_train_begin)
on_train_begin(pbar:PBar, metrics_names:StrList, **kwargs:Any)
Initialize recording status at beginning of training.
The following functions are used along the way by the Recorder or can be called by other callbacks.
show_doc(Recorder.add_metrics)
show_doc(Recorder.add_metric_names)
show_doc(Recorder.format_stats)
show_doc(fit)
Note that you have to create the Optimizer yourself if you call this function, whereas Learner.fit creates it for you automatically.
show_doc(train_epoch)
train_epoch(model:Module, dl:DataLoader, opt:Optimizer, loss_func:LossFunction)
Simple training of model for 1 epoch of dl using optim opt and loss function loss_func.
You won't generally need to call this yourself; it's what fit calls for each epoch.
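The structure of one training epoch (for each mini-batch: forward pass, loss, gradients, optimizer step) can be sketched without PyTorch. This is a pure-Python stand-in with a hypothetical 1-D linear model and hand-written MSE gradients; in PyTorch terms, the gradient lines correspond to loss.backward() and the update line to opt.step() followed by opt.zero_grad().

```python
def train_epoch(params, batches, lr):
    """Sketch: one epoch of mini-batch SGD for y = w*x + b with MSE loss."""
    w, b = params
    for xb, yb in batches:
        n = len(xb)
        preds = [w * x + b for x in xb]                        # forward pass
        errs = [p - y for p, y in zip(preds, yb)]
        grad_w = sum(2 * e * x for e, x in zip(errs, xb)) / n  # d(MSE)/dw
        grad_b = sum(2 * e for e in errs) / n                  # d(MSE)/db
        w, b = w - lr * grad_w, b - lr * grad_b                # optimizer step
    return w, b

xs = [i / 10 for i in range(20)]
data = [(xs[i:i + 5], [3 * x + 1 for x in xs[i:i + 5]]) for i in range(0, 20, 5)]
params = (0.0, 0.0)
for epoch in range(500):
    params = train_epoch(params, data, lr=0.1)
print(params)   # should approach (3, 1)
```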
show_doc(validate)
validate(model:Module, dl:DataLoader, loss_func:OptLossFunc=None, cb_handler:Optional[CallbackHandler]=None, pbar:Union[MasterBar, ProgressBar, NoneType]=None, average=True, n_batch:Optional[int]=None) → Iterator[Tuple[IntOrTensor, Ellipsis]]
Calculate loss_func of model on dl in evaluation mode.
This is what fit calls after each epoch. You can call it if you want to run inference on a DataLoader manually.
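When the per-batch losses are averaged into one number, batches can have different sizes (the last one is often smaller), so a mean weighted by batch size is the exact dataset-level mean. Our understanding is that this is what average=True does; the sketch below is hypothetical and assumes validation batches of 64 over the 2038 MNIST_SAMPLE validation items, with made-up batch losses.

```python
def average_loss(batch_losses, batch_sizes):
    """Sketch: combine per-batch mean losses into a dataset mean, weighted by batch size."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

# 2038 items in batches of 64: 31 full batches plus a final batch of 54.
sizes = [64] * 31 + [54]
batch_losses = [0.1] * 31 + [0.4]          # made-up per-batch losses
print(sum(sizes))                          # 2038
print(average_loss(batch_losses, sizes))   # 220 / 2038, slightly below the unweighted mean
print(sum(batch_losses) / len(batch_losses))
```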
show_doc(get_preds)
get_preds(model:Module, dl:DataLoader, pbar:Union[MasterBar, ProgressBar, NoneType]=None, cb_handler:Optional[CallbackHandler]=None, activ:Module=None, loss_func:OptLossFunc=None, n_batch:Optional[int]=None) → List[Tensor]
Tuple of predictions and targets, and optional losses (if loss_func) using dl, max batches n_batch.
show_doc(loss_batch)
loss_batch(model:Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None, cb_handler:Optional[CallbackHandler]=None) → Tuple[Union[Tensor, int, float, str]]
Calculate loss and metrics for a batch, call out to callbacks as necessary.
show_doc(LearnerCallback, title_level=3)
show_doc(RecordOnCPU, title_level=3)
show_doc(Learner.tta_only)
_tta_only(learn:Learner, ds_type:DatasetType=<DatasetType.Valid: 2>, scale:float=1.35) → Iterator[List[Tensor]]
Computes the outputs for several augmented inputs for TTA.
show_doc(Learner.TTA)
_TTA(learn:Learner, beta:float=0.4, scale:float=1.35, ds_type:DatasetType=<DatasetType.Valid: 2>, with_loss:bool=False) → Tensors
Applies TTA to predict on ds_type dataset.
show_doc(RecordOnCPU.on_batch_begin)
on_batch_begin(last_input, last_target, **kwargs)
Set HP before the step is done. Returns xb, yb (which can allow us to modify the input at that step if needed).
show_doc(Learner.purge)