This module provides both a standalone class and a callback for registering and automatically deregistering PyTorch hooks, along with some pre-defined hooks. Hooks can be attached to any nn.Module, for either the forward or the backward pass.
We'll start by looking at the pre-defined hook ActivationStats, then we'll see how to create our own.
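All of these classes build on PyTorch's own hook API. The following is a plain-PyTorch sketch, not fastai code, and it assumes PyTorch 1.8 or later for register_full_backward_hook. It attaches a forward hook and a backward hook to one layer, runs a forward and backward pass, then deregisters both hooks through the handles PyTorch returns:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
seen = {}

def forward_hook(module, inputs, output):
    # Runs right after module.forward; inputs is a tuple, output is the result.
    seen["forward"] = output.shape

def backward_hook(module, grad_input, grad_output):
    # Runs during backward; grad_output holds gradients w.r.t. the module's outputs.
    seen["backward"] = grad_output[0].shape

layer = model[2]  # the last Linear layer (its input requires grad, so the backward hook fires)
fwd_handle = layer.register_forward_hook(forward_hook)
bwd_handle = layer.register_full_backward_hook(backward_hook)

model(torch.randn(3, 4)).sum().backward()

# The handles returned by register_* are what deregister the hooks.
fwd_handle.remove()
bwd_handle.remove()
```

Keeping the handles around and calling remove() yourself is exactly the bookkeeping the classes below automate.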
from fastai.gen_doc.nbdoc import *
from fastai.callbacks.hooks import *
from fastai.train import *
from fastai.vision import *
show_doc(ActivationStats)
class ActivationStats(learn:Learner, modules:Sequence[Module]=None, do_remove:bool=True) :: HookCallback
Callback that records the mean and std of activations.
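ActivationStats saves the mean and standard deviation of the layer activations in self.stats for all modules passed to it. By default (when modules is None) it hooks every module of the model that has a weight, such as the convolutional and linear layers. For instance: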
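To make the idea concrete, here is a plain-PyTorch sketch of what such a statistics hook does; it is not fastai's implementation. It assumes a small hypothetical CNN, approximates the default module selection by keeping every module that has a weight attribute, and records one (mean, std) pair per hooked module for a single batch:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; only the Conv2d and Linear layers own a `weight`.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, stride=2), nn.ReLU(),
    nn.Conv2d(8, 16, 3, stride=2), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 2),
)

# Approximation of the default: hook every module that has a weight.
modules = [m for m in model.modules() if hasattr(m, "weight")]

latest = {}

def stats_hook(module, inputs, output):
    # Record (mean, std) of this module's activations for the current batch.
    latest[module] = (output.mean().item(), output.std().item())

handles = [m.register_forward_hook(stats_hook) for m in modules]
model(torch.randn(4, 3, 16, 16))
for h in handles:
    h.remove()

for m in modules:
    mean, std = latest[m]
    print(type(m).__name__, f"mean={mean:.3f}", f"std={std:.3f}")
```

Storing scalars via .item() rather than whole activation tensors keeps memory use constant no matter how large the layers are.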
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
#learn = cnn_learner(data, models.resnet18, callback_fns=ActivationStats)
learn = Learner(data, simple_cnn((3,16,16,2)), callback_fns=ActivationStats)
learn.fit(1)
epoch | train_loss | valid_loss |
---|---|---|
1 | 0.112384 | 0.083544 |
The saved stats tensor is a FloatTensor of shape (2,num_modules,num_batches). The first axis indexes (mean,stdev).
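One way to arrive at that layout, sketched here under the assumption that each training batch contributes one (mean, std) pair per module, is to collect a (num_batches, num_modules, 2) nested list and permute the axes. The values below are dummy placeholders chosen only so the indexing is easy to check:

```python
import torch

num_modules, num_batches = 3, 5
per_batch = []  # one entry per training batch

for batch_idx in range(num_batches):
    # Placeholder for the hooks' stored results: one (mean, std) pair per module.
    stored = [(float(batch_idx), 1.0 + m) for m in range(num_modules)]
    per_batch.append(stored)

# (num_batches, num_modules, 2) -> (2, num_modules, num_batches)
stats = torch.tensor(per_batch).permute(2, 1, 0)

means = stats[0]                # (num_modules, num_batches)
std_second_last = stats[1][-2]  # std of the second-to-last module, one value per batch
```

With this layout, stats[1][-2] is a 1-D series over batches, which is what the plot below draws.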
len(learn.data.train_dl),len(learn.activation_stats.modules)
(193, 3)
learn.activation_stats.stats.shape
torch.Size([2, 3, 193])
So this plots the standard deviation (axis0==1) of the second-to-last module (axis1==-2) for each batch (axis2):
plt.plot(learn.activation_stats.stats[1][-2].numpy());
show_doc(ActivationStats.hook)
You don't call these yourself; they're called automatically by fastai's Callback system to enable the class's functionality.
show_doc(ActivationStats.on_train_begin)
on_train_begin(**kwargs)
Initialize stats.
show_doc(ActivationStats.on_batch_end)
on_batch_end(train, **kwargs)
Take the stored results and put them in self.stats.
show_doc(ActivationStats.on_train_end)
on_train_end(**kwargs)
Polish the final result.
show_doc(Hook)
class Hook(m:Module, hook_func:HookFunc, is_forward:bool=True, detach:bool=True)
Create a hook on m with hook_func.
Registers a PyTorch hook, which you then deregister manually with remove. Your hook_func will be called automatically whenever the forward or backward pass (depending on is_forward) of your module m is run, and the result of that function is placed in self.stored.
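The behavior just described can be sketched in plain PyTorch. SimpleHook below is a hypothetical class written for illustration, not fastai's source; it assumes PyTorch 1.8+ for register_full_backward_hook, and it also supports use as a context manager so the hook is removed on exit:

```python
import torch
import torch.nn as nn


def _detach(x):
    # Detach a tensor, or every tensor inside a tuple/list; leave anything else (e.g. None) as is.
    if isinstance(x, (tuple, list)):
        return type(x)(_detach(o) for o in x)
    return x.detach() if isinstance(x, torch.Tensor) else x


class SimpleHook:
    """Hypothetical sketch of the Hook idea: register, store the result, remove."""

    def __init__(self, m, hook_func, is_forward=True, detach=True):
        self.hook_func, self.detach, self.stored = hook_func, detach, None
        register = m.register_forward_hook if is_forward else m.register_full_backward_hook
        self.handle = register(self.hook_fn)
        self.removed = False

    def hook_fn(self, module, inputs, outputs):
        if self.detach:
            inputs, outputs = _detach(inputs), _detach(outputs)
        # Whatever hook_func returns becomes the stored result.
        self.stored = self.hook_func(module, inputs, outputs)

    def remove(self):
        # Idempotent: a second call does nothing.
        if not self.removed:
            self.handle.remove()
            self.removed = True

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.remove()


layer = nn.Linear(4, 3)

# Forward hook used as a context manager: removed automatically on exit.
with SimpleHook(layer, lambda m, i, o: o.shape) as fwd:
    layer(torch.randn(2, 4))

# Backward hook: hook_func receives (module, grad_input, grad_output).
bwd = SimpleHook(layer, lambda m, gi, go: go[0].shape, is_forward=False)
x = torch.randn(5, 4, requires_grad=True)
layer(x).sum().backward()
bwd.remove()
bwd.remove()  # safe to call twice
```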
show_doc(Hook.remove)
remove()
Remove the hook from the model.
Deregister the hook; does nothing if it has already been removed.
show_doc(Hooks)
class Hooks(ms:ModuleList, hook_func:HookFunc, is_forward:bool=True, detach:bool=True)
Create several hooks on the modules in ms with hook_func.
Acts as a Collection (i.e. len(hooks) and hooks[i]) and an Iterator (i.e. for hook in hooks) of a group of hooks, one for each module in ms, with the ability to remove them all as a group. Use stored to get all hook results. hook_func and is_forward behave the same as in Hook. See the source code for HookCallback for a simple example.
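A group like this is easy to sketch on top of plain forward hooks. SimpleHooks and _StoredOutput below are hypothetical names written for illustration; the sketch supports len, indexing, iteration, a stored property listing every result, and removal of the whole group. It assumes every hooked module returns a single tensor:

```python
import torch
import torch.nn as nn


class _StoredOutput:
    """One forward hook that keeps hook_func's latest result in .stored."""

    def __init__(self, module, hook_func):
        self.stored = None

        def fn(mod, inputs, output):
            # Assumes the module returns a single tensor.
            self.stored = hook_func(mod, inputs, output.detach())

        self.handle = module.register_forward_hook(fn)


class SimpleHooks:
    """Hypothetical sketch of a hook group: indexable, iterable, removable together."""

    def __init__(self, modules, hook_func):
        self.hooks = [_StoredOutput(m, hook_func) for m in modules]
        self.removed = False

    def __len__(self):
        return len(self.hooks)

    def __getitem__(self, i):
        return self.hooks[i]

    def __iter__(self):
        return iter(self.hooks)

    @property
    def stored(self):
        # All hook results, in module order.
        return [h.stored for h in self.hooks]

    def remove(self):
        if not self.removed:
            for h in self.hooks:
                h.handle.remove()
            self.removed = True


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
hooks = SimpleHooks(model, lambda m, i, o: o.shape)  # iterating a Sequential yields its children
model(torch.randn(3, 4))
shapes = hooks.stored
hooks.remove()
```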
show_doc(Hooks.remove)
remove()
Remove the hooks from the model.
Deregister all hooks created by this class; does nothing if they have already been removed.
show_doc(hook_output)
Function that creates a Hook for module that simply stores the output of the layer.
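A common use for an output-storing hook is grabbing intermediate features without changing the model. Here is a plain-PyTorch sketch of that pattern; OutputCatcher is a hypothetical stand-in, not fastai's hook_output, and the small CNN is made up for illustration:

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)


class OutputCatcher:
    """Hypothetical stand-in for an output hook: keeps the latest output of one module."""

    def __init__(self, module):
        self.stored = None
        self.handle = module.register_forward_hook(self._fn)

    def _fn(self, module, inputs, output):
        # Detach so the stored features don't keep the autograd graph alive.
        self.stored = output.detach()

    def remove(self):
        self.handle.remove()


catcher = OutputCatcher(model[3])  # the flattened 8-d features feeding the final Linear
logits = model(torch.randn(2, 3, 32, 32))
features = catcher.stored
catcher.remove()
```

The single forward pass yields both the model's normal output (logits) and the intermediate activations (features).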
show_doc(hook_outputs)
hook_outputs(modules:ModuleList, detach:bool=True, grad:bool=False) → Hooks
Return Hooks that store activations of all modules in self.stored.
Function that creates a Hook for each of the passed modules, each of which simply stores the output of its layer, and returns them together as Hooks. For example, the (slightly simplified) source code of model_sizes is:
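def model_sizes(m, size):
    hooks = hook_outputs(m)  # register the hooks *before* running the model
    x = m(torch.zeros(1, in_channels(m), *size))  # the dummy forward pass fills each hook's .stored
    sizes = [o.stored.shape for o in hooks]
    hooks.remove()  # deregister the hooks once the sizes are collected
    return sizes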
show_doc(model_sizes)
model_sizes(m:Module, size:tuple=(64, 64)) → Tuple[Sizes, Tensor, Hooks]
Pass a dummy input through the model m to get the various sizes of activations.
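The same technique can be sketched without fastai. sketch_model_sizes is a hypothetical helper: unlike fastai, which infers the channel count with in_channels(m), it takes in_channels as an explicit argument, and it hooks only the direct children of the model:

```python
import torch
import torch.nn as nn


def sketch_model_sizes(model, size=(64, 64), in_channels=3):
    """Hypothetical: output shape of each direct child for a 1-item dummy batch."""
    sizes = []
    handles = [child.register_forward_hook(lambda m, i, o: sizes.append(o.shape))
               for child in model.children()]
    try:
        with torch.no_grad():
            model.eval()(torch.zeros(1, in_channels, *size))
    finally:
        for h in handles:
            h.remove()  # always deregister, even if the forward pass fails
    return sizes


model = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, stride=2, padding=1),
)
sizes = sketch_model_sizes(model, (64, 64))
```

Because the hooks fire in execution order, the list lines up with the layers of a Sequential; each stride-2 convolution halves the spatial size here.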
show_doc(model_summary)
model_summary(m:Learner, n:int=70)
Tests found for model_summary:
pytest -sv tests/test_callbacks_hooks.py::test_model_summary_vision
pytest -sv tests/test_callbacks_hooks.py::test_model_summary_text
pytest -sv tests/test_callbacks_hooks.py::test_model_summary_tabular
pytest -sv tests/test_callbacks_hooks.py::test_model_summary_collab
Print a summary of m using an output text width of n characters.
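A summary like this can be built from forward hooks as well. The sketch below is hypothetical and is not fastai's model_summary: it takes a plain nn.Module and an input size instead of a Learner, hooks every leaf module, and formats one row per layer with columns that add up to the default width of 70 characters:

```python
import torch
import torch.nn as nn


def sketch_summary(model, input_size, n=70):
    """Hypothetical model summary: one row per leaf module (layer type, output shape, params)."""
    rows = []

    def hook(m, inputs, output):
        n_params = sum(p.numel() for p in m.parameters(recurse=False))
        rows.append((type(m).__name__, str(list(output.shape)), n_params))

    leaves = [m for m in model.modules() if len(list(m.children())) == 0]
    handles = [m.register_forward_hook(hook) for m in leaves]
    try:
        with torch.no_grad():
            model.eval()(torch.zeros(1, *input_size))  # note: leaves the model in eval mode
    finally:
        for h in handles:
            h.remove()

    lines = ["=" * n, f"{'Layer (type)':<25}{'Output Shape':<30}{'Param #':>15}", "=" * n]
    lines += [f"{name:<25}{shape:<30}{p:>15,}" for name, shape, p in rows]
    lines += ["=" * n, f"Total params: {sum(r[2] for r in rows):,}"]
    return "\n".join(lines)


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
out = sketch_summary(model, (4,))
print(out)
```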
show_doc(num_features_model)
num_features_model(m:Module) → int
Return the number of output features for the model m.
It can be useful to get the size of each layer of a model (e.g. for printing a summary, or for generating cross-connections for a DynamicUnet); however, these sizes depend on the size of the input. This function calculates the layer sizes by passing a minimal dummy tensor through the model.
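A typical use is sizing a new head for a convolutional body. The sketch below is hypothetical, not fastai's num_features_model: it assumes a 3-channel image input and a fixed 64x64 dummy size, and it reads the channel dimension of the output:

```python
import torch
import torch.nn as nn


def sketch_num_features(model, in_channels=3, size=(64, 64)):
    """Hypothetical: number of output channels (dim 1) of `model` on a dummy batch."""
    with torch.no_grad():
        out = model.eval()(torch.zeros(1, in_channels, *size))
    return out.shape[1]


body = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 64, 3, stride=2, padding=1),
)
n_feat = sketch_num_features(body)
head = nn.Linear(n_feat, 10)  # size the new head from the measured feature count
```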
show_doc(dummy_batch)
dummy_batch(m:Module, size:tuple=(64, 64)) → Tensor
Create a dummy batch of spatial size size to go through m.
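Here is a sketch of building such a batch by hand. sketch_dummy_batch is a hypothetical helper; it assumes that the model's first Conv2d determines the number of input channels, and it matches the device and dtype of the model's parameters:

```python
import torch
import torch.nn as nn


def sketch_dummy_batch(model, size=(64, 64)):
    """Hypothetical: a 1-item batch shaped for the model's first Conv2d."""
    first_conv = next(m for m in model.modules() if isinstance(m, nn.Conv2d))
    param = next(model.parameters())
    # Random values in [-1, 1), on the same device and with the same dtype as the model.
    return torch.empty(1, first_conv.in_channels, *size,
                       device=param.device, dtype=param.dtype).uniform_(-1.0, 1.0)


model = nn.Sequential(nn.Conv2d(5, 8, 3), nn.ReLU())
batch = sketch_dummy_batch(model, (10, 12))
```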
show_doc(dummy_eval)
dummy_eval(m:Module, size:tuple=(64, 64))
Pass a dummy_batch of size size through m in evaluation mode.
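Running a one-item probe batch in evaluation mode matters because of layers like BatchNorm and Dropout. The small plain-PyTorch sketch below (with a made-up model) shows that BatchNorm1d rejects a batch of one in training mode, while in eval mode the same input always gives the same output:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.BatchNorm1d(8))
batch = torch.empty(1, 8).uniform_(-1.0, 1.0)  # a 1-item dummy batch

# Training mode: BatchNorm1d cannot compute batch statistics from a single sample.
model.train()
try:
    model(batch)
    train_ok = True
except ValueError:
    train_ok = False

# Evaluation mode: BatchNorm uses its running statistics and Dropout is disabled.
out1 = model.eval()(batch)
out2 = model(batch)
```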
show_doc(HookCallback)
class HookCallback(learn:Learner, modules:Sequence[Module]=None, do_remove:bool=True) :: LearnerCallback
Callback that can be used to register hooks on modules. Implement the corresponding function in self.hook.
For all modules, uses a callback to automatically register a method self.hook (which you must define in an inherited class) as a hook. This method must have the signature:
def hook(self, m:Model, input:Tensors, output:Tensors)
If do_remove is True, the hook is automatically deregistered at the end of training. See ActivationStats for a simple example of inheriting from this class.
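The subclassing pattern is sketched below in plain PyTorch, without fastai's Learner or Callback machinery. MiniHookCallback and MaxActivation are hypothetical names. The base class registers self.hook at the start of training (by default on every module with a weight, mirroring the default described for ActivationStats), and it removes the hooks at the end when do_remove is set. The subclass only supplies hook:

```python
import torch
import torch.nn as nn


class MiniHookCallback:
    """Hypothetical base class: registers self.hook on modules for the duration of training."""

    def __init__(self, model, modules=None, do_remove=True):
        self.model, self.modules, self.do_remove = model, modules, do_remove

    def on_train_begin(self):
        if not self.modules:
            # Assumed default: every module that owns a weight.
            self.modules = [m for m in self.model.modules() if hasattr(m, "weight")]
        self.stored = {}
        self.handles = [m.register_forward_hook(self._on_forward) for m in self.modules]

    def _on_forward(self, module, inputs, output):
        self.stored[module] = self.hook(module, inputs, output.detach())

    def hook(self, m, inputs, output):
        raise NotImplementedError("subclasses must define hook")

    def on_train_end(self):
        if self.do_remove:
            for h in self.handles:
                h.remove()


class MaxActivation(MiniHookCallback):
    """Example subclass: remembers the largest activation of each hooked module."""

    def hook(self, m, inputs, output):
        return output.max().item()


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
cb = MaxActivation(model)
cb.on_train_begin()
model(torch.randn(16, 4))  # stands in for a training step
cb.on_train_end()
```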
You don't call these yourself; they're called automatically by fastai's Callback system to enable the class's functionality.
show_doc(HookCallback.on_train_begin)
on_train_begin(**kwargs)
Register the Hooks on self.modules.
show_doc(HookCallback.on_train_end)
on_train_end(**kwargs)
Remove the Hooks.
show_doc(HookCallback.remove)
remove()
show_doc(Hook.hook_fn)
hook_fn(module:Module, input:Tensors, output:Tensors)
Applies hook_func to module, input, output.
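The detach argument of Hook (True by default) controls whether hook_func sees tensors that are still attached to the autograd graph. The plain-PyTorch sketch below uses a hypothetical make_hook helper to illustrate the difference: a stored tensor that is not detached still requires grad and keeps the graph that produced it alive for as long as it is stored:

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 3)
captured = {}


def make_hook(name, detach):
    def hook_fn(module, inputs, output):
        # With detach=True the stored tensor is cut off from the autograd graph.
        captured[name] = output.detach() if detach else output
    return hook_fn


h1 = layer.register_forward_hook(make_hook("detached", True))
h2 = layer.register_forward_hook(make_hook("attached", False))
layer(torch.randn(2, 3))
h1.remove()
h2.remove()
```

Detaching is usually what you want when you only record statistics or features, since it avoids holding onto memory for gradients you will never compute.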