#!/usr/bin/env python
# coding: utf-8

# ## Additional training functions

# [`train`](/train.html#train) provides a number of extension methods that are added to [`Learner`](/basic_train.html#Learner) (see below for a list and details), along with four simple callbacks:
#
# - [`ShowGraph`](/train.html#ShowGraph)
# - [`GradientClipping`](/train.html#GradientClipping)
# - [`BnFreeze`](/train.html#BnFreeze)
# - [`AccumulateScheduler`](/train.html#AccumulateScheduler)

# In[1]:

from fastai.gen_doc.nbdoc import *
from fastai.train import *
from fastai.vision import *

# ## [`Learner`](/basic_train.html#Learner) extension methods

# These methods are automatically added to all [`Learner`](/basic_train.html#Learner) objects created after importing this module. They provide convenient access to a number of callbacks, without requiring them to be created manually.

# In[2]:

show_doc(fit_one_cycle)

# In[3]:

show_doc(one_cycle_scheduler)

# See [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) for details.

# In[4]:

show_doc(lr_find)

# See [`LRFinder`](/callbacks.lr_finder.html#LRFinder) for details.

# In[5]:

show_doc(to_fp16)

# See [`MixedPrecision`](/callbacks.fp16.html#MixedPrecision) for details.

# In[6]:

show_doc(to_fp32)

# In[7]:

show_doc(mixup)

# See [`MixUpCallback`](/callbacks.mixup.html#MixUpCallback) for more details.

# In[8]:

show_doc(Interpretation)

# In[9]:

show_doc(Interpretation.from_learner)

# In[10]:

show_doc(Interpretation.top_losses)

# For example, in [`ClassificationInterpretation`](/train.html#ClassificationInterpretation) this is implemented using an argmax on `preds` to set `self.pred_class`, whereas an optional sigmoid is used for `MultiLabelClassificationInterpretation`.

# In[11]:

show_doc(ClassificationInterpretation)

# In[ ]:

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = cnn_learner(data, models.resnet18)
learn.fit(1)
preds,y,losses = learn.get_preds(with_loss=True)
interp = ClassificationInterpretation(learn, preds, y, losses)

# In[12]:

show_doc(ClassificationInterpretation.top_losses)

# Returns a tuple of *(losses, indices)*.

# In[ ]:

interp.top_losses(9)
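# The returned indices index into the validation set, so you can use them to pull out the hardest examples. Here is a minimal sketch of that lookup (the `data.valid_ds` access pattern is an assumption for illustration; this cell is not part of the original docs):

# In[ ]:

# Illustrative sketch: pair each of the 9 largest losses with its validation target.
losses, idxs = interp.top_losses(9)
for loss, i in zip(losses, idxs):
    # data.valid_ds holds the validation items; y gives the target labels
    print(f'loss={float(loss):.4f}  target={data.valid_ds.y[int(i)]}')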
# In[13]:

show_doc(ClassificationInterpretation.plot_confusion_matrix)

# If `normalize` is `True`, the percentages are plotted with `norm_dec` digits. `slice_size` can be used to avoid an out-of-memory error if your dataset is too big. `kwargs` are passed to `plt.figure`.

# In[ ]:

interp.plot_confusion_matrix()

# In[14]:

show_doc(ClassificationInterpretation.confusion_matrix)

# In[ ]:

interp.confusion_matrix()

# In[15]:

show_doc(ClassificationInterpretation.most_confused)

# In[16]:

show_doc(MultiLabelClassificationInterpretation)

# In[ ]:

jekyll_warn("MultiLabelClassificationInterpretation is not implemented yet. Feel free to implement it :)")

# #### Working with large datasets

# When working with large datasets, memory problems can arise when computing the confusion matrix. For example, the error can look like this:
#
#     RuntimeError: $ Torch: not enough memory: you tried to allocate 64GB. Buy new RAM!
#
# In this case it is possible to force [`ClassificationInterpretation`](/train.html#ClassificationInterpretation) to compute the confusion matrix for data slices and then aggregate the results, by specifying the `slice_size` parameter.

# In[ ]:

interp.confusion_matrix(slice_size=10)

# In[ ]:

interp.plot_confusion_matrix(slice_size=10)

# In[ ]:

interp.most_confused(slice_size=10)

# ## Additional callbacks

# We'll show examples below using our MNIST sample. As usual, the `on_something` methods are called directly by the fastai library; you don't need to call them yourself.

# In[ ]:

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)

# In[17]:

show_doc(ShowGraph, title_level=3)

# ```python
# learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=ShowGraph)
# learn.fit(3)
# ```
#
# ![Training graph](imgs/train_graph.gif)

# In[18]:

show_doc(ShowGraph.on_epoch_end)

# In[19]:

show_doc(GradientClipping)

# In[ ]:

learn = cnn_learner(data, models.resnet18, metrics=accuracy,
                    callback_fns=partial(GradientClipping, clip=0.1))
learn.fit(1)

# In[20]:

show_doc(GradientClipping.on_backward_end)

# In[21]:

show_doc(BnFreeze)

# For batchnorm layers where `requires_grad==False`, you generally don't want to update their moving average statistics, in order to avoid the model's statistics getting out of sync with its pre-trained weights. You can add this callback to automate the freezing of those statistics (internally, it calls `eval` on these layers).

# In[ ]:

learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=BnFreeze)
learn.fit(1)

# In[22]:

show_doc(BnFreeze.on_epoch_begin)

# In[23]:

show_doc(AccumulateScheduler)

# Let's force `batch_size=2` to mimic a scenario where we can't fit enough batch samples into memory. We can then set `n_step` as desired to get an effective batch size of `effective_batch_size=batch_size*n_step`.
#
# It is also important to use a loss function with `reduction='sum'` in order to calculate the exact average of the accumulated gradients.
#
# Another important note for users is that `batchnorm` is not yet adapted to accumulated gradients, so you should use this callback at your own risk until a hero fixes it :)
#
# Here we demonstrate this callback with a model without `batchnorm` layers; alternatively, you can use `nn.InstanceNorm` or [`nn.GroupNorm`](https://pytorch.org/docs/stable/nn.html#torch.nn.GroupNorm).
#
# ```python
# from torchvision.models import vgg11
#
# data = ImageDataBunch.from_folder(path, bs=2)
#
# learn = cnn_learner(data, vgg11, metrics=accuracy, loss_func=CrossEntropyFlat(reduction='sum'),
#                     callback_fns=partial(AccumulateScheduler, n_step=16))
# learn.fit(1)
# ```

# ## Undocumented Methods - Methods moved below this line will intentionally be hidden

# ## New Methods - Please document or move to the undocumented section

# In[24]:

show_doc(ClassificationInterpretation.plot_top_losses)

# In[25]:

show_doc(ClassificationInterpretation.from_learner)

# In[26]:

show_doc(ClassificationInterpretation.top_losses)

# In[27]:

show_doc(ClassificationInterpretation.confusion_matrix)

# In[28]:

show_doc(ClassificationInterpretation.most_confused)

# In[29]:

show_doc(ClassificationInterpretation.plot_confusion_matrix)

# In[30]:

show_doc(ClassificationInterpretation.plot_multi_top_losses)
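# As a starting point for documenting `plot_top_losses`, here is a minimal usage sketch (the argument values are illustrative assumptions, not from the original docs):

# In[ ]:

# Show the 9 validation images with the highest loss, with predicted vs. actual labels.
interp.plot_top_losses(9, figsize=(7,7))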