This provides both a standalone class and a callback for registering and automatically deregistering PyTorch hooks, along with some pre-defined hooks. Hooks can be attached to any nn.Module, for either the forward or the backward pass. We'll start by looking at the pre-defined hook ActivationStats, then we'll see how to create our own.
from fastai.gen_doc.nbdoc import *
from fastai.callbacks.hooks import *
from fastai import *
from fastai.train import *
from fastai.vision import *
show_doc(ActivationStats)
class ActivationStats [source]
ActivationStats(learn:Learner, modules:Sequence[Module]=None, do_remove:bool=True) :: HookCallback
Callback that records the activations.
ActivationStats saves the layer activations in self.stats for all modules passed to it. By default it will save activations for all modules. For instance:
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = create_cnn(data, models.resnet18, callback_fns=ActivationStats)
learn.fit(1)
Total time: 00:13
epoch  train loss  valid loss
0      0.077055    0.049985    (00:13)
The saved stats is a FloatTensor of shape (2, num_modules, num_batches). The first axis is (mean, stdev).
len(learn.data.train_dl),len(learn.activation_stats.modules)
(194, 44)
learn.activation_stats.stats.shape
torch.Size([2, 44, 194])
So this shows the standard deviation (axis0==1) of the 5th-last layer (axis1==-5) for each batch (axis2):
plt.plot(learn.activation_stats.stats[1][-5].numpy());
show_doc(Hook)
Registers and manually deregisters a PyTorch hook. Your hook_func will be called automatically when the forward (or backward, depending on is_forward) pass for your module m is run, and the result of that function is placed in self.stored.
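As an illustration, here is a minimal sketch (assuming a fastai v1 environment and a Hook(m, hook_func) constructor taking the module and the hook function; the nn.Linear layer is a throwaway example) that stores a layer's output during the forward pass and then removes the hook:
import torch
import torch.nn as nn
from fastai.callbacks.hooks import Hook

layer = nn.Linear(10, 4)                    # toy module, just for illustration
hook = Hook(layer, lambda m, i, o: o)       # forward hook; the returned value ends up in hook.stored
_ = layer(torch.randn(2, 10))
print(hook.stored.shape)                    # torch.Size([2, 4])
hook.remove()                               # deregister the hook when done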
show_doc(Hook.remove)
remove [source]
remove()
Deregister the hook, if not called already.
show_doc(Hooks)
class Hooks [source]
Hooks(ms:ModuleList, hook_func:HookFunc, is_forward:bool=True, detach:bool=True)
Create several hooks.
Acts as a Collection (i.e. len(hooks) and hooks[i]) and an Iterator (i.e. for hook in hooks) over a group of hooks, one for each module in ms, with the ability to remove them all as a group. Use stored to get all hook results. The hook_func and is_forward parameters behave the same as in Hook. See the source code for HookCallback for a simple example.
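For instance, here is a minimal sketch (assuming a fastai v1 environment; the small model below is made up only for this example) that hooks every layer of a model and gathers all of their outputs in one pass:
import torch
import torch.nn as nn
from fastai.callbacks.hooks import Hooks

model = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 2))   # toy model
hooks = Hooks(list(model.children()), lambda m, i, o: o)              # one forward hook per layer
_ = model(torch.randn(4, 10))
print([o.shape for o in hooks.stored])                                # one stored output per hooked module
hooks.remove()                                                        # remove all hooks as a group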
show_doc(Hooks.remove)
remove [source]
remove()
Deregister all hooks created by this class, if not previously called.
show_doc(hook_output)
Function that creates a Hook for module that simply stores the output of the layer.
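As a quick illustration (a sketch assuming a fastai v1 environment; the nn.Linear layer is just a placeholder):
import torch
import torch.nn as nn
from fastai.callbacks.hooks import hook_output

layer = nn.Linear(10, 4)
hook = hook_output(layer)        # stores the layer's output on each forward pass
_ = layer(torch.randn(2, 10))
print(hook.stored.shape)         # torch.Size([2, 4])
hook.remove()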
show_doc(hook_outputs)
Function that creates a Hook for all passed modules that simply stores the output of the layers. For example, the (slightly simplified) source code of model_sizes is:
def model_sizes(m, size):
    hooks = hook_outputs(m)                          # register hooks before the forward pass
    x = m(torch.zeros(1, in_channels(m), *size))     # dummy forward pass through the model
    return [o.stored.shape for o in hooks]           # one stored output shape per module
show_doc(model_sizes)
show_doc(num_features_model)
It can be useful to get the size of each layer of a model (e.g. for printing a summary, or for generating cross-connections for a DynamicUnet); however, those sizes depend on the size of the input. This function calculates the layer sizes by passing in a minimal tensor of the given size.
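For example, a minimal sketch (assuming a fastai v1 environment; the tiny convnet below is made up purely for illustration):
import torch.nn as nn
from fastai.callbacks.hooks import model_sizes, num_features_model

body = nn.Sequential(nn.Conv2d(3, 8, 3, stride=2), nn.ReLU(),
                     nn.Conv2d(8, 16, 3, stride=2))      # toy "body", not a real backbone
print(model_sizes(body, size=(64, 64)))   # activation shapes for a 1x3x64x64 dummy input
print(num_features_model(body))           # number of output features of the body (16 here)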
show_doc(HookCallback)
class HookCallback [source]
HookCallback(learn:Learner, modules:Sequence[Module]=None, do_remove:bool=True) :: LearnerCallback
Callback that registers given hooks.
For all modules, uses a callback to automatically register a method self.hook (that you must define in an inherited class) as a hook. This method must have the signature:
def hook(self, m:Model, input:Tensors, output:Tensors)
If do_remove then the hook is automatically deregistered at the end of training. See ActivationStats for a simple example of inheriting from this class.
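As a rough illustration, here is a minimal sketch of such a subclass (assuming a fastai v1 environment; MeanStats is a made-up name, and the statistic it records is chosen only for the example):
from fastai.callbacks.hooks import HookCallback

class MeanStats(HookCallback):
    "Record the mean of every hooked module's output on each training batch."
    def on_train_begin(self, **kwargs):
        super().on_train_begin(**kwargs)   # registers self.hook on all self.modules
        self.means = []
    def hook(self, m, i, o):
        return o.mean().item()             # stored for each module via the underlying Hooks
    def on_batch_end(self, train, **kwargs):
        if train: self.means.append(self.hooks.stored)

# Used like any other callback, e.g.: create_cnn(data, models.resnet18, callback_fns=MeanStats)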
show_doc(HookCallback.remove)
remove [source]
remove()
show_doc(HookCallback.on_train_begin)
show_doc(HookCallback.on_train_end)
show_doc(ActivationStats.hook)
show_doc(ActivationStats.on_batch_end)
show_doc(ActivationStats.on_train_begin)
show_doc(ActivationStats.on_train_end)
show_doc(Hook.hook_fn)