In [1]:

```
try:
    import torchbearer
except ImportError:
    !pip install -q torchbearer
    import torchbearer

print(torchbearer.__version__)
```

Let's assume we have a basic binary classification task with 100-dimensional samples as input and a binary label as output, and that we would like to solve it with a 2-layer neural network. Finally, we also want to keep track of the sum of the hidden outputs for some arbitrary reason; for this we use the state functionality of torchbearer.

We create a state key for the mock sum we want to track.

In [2]:

```
MOCK = torchbearer.state_key('mock')
```
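Conceptually (a simplified sketch, not torchbearer's actual implementation), a state key is just a unique object used as a dictionary key, which avoids collisions between components that might otherwise pick the same string name:

```python
# Illustrative sketch only: torchbearer's real StateKey does more
# (e.g. mangling duplicate names to guarantee uniqueness).
class StateKey:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


MOCK = StateKey('mock')
state = {MOCK: 42.0}   # state is a plain dict keyed by key objects
print(state[MOCK])     # 42.0
```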

Here is our basic 2-layer neural network.

In [3]:

```
import torch
import torch.nn as nn


class BasicModel(nn.Module):
    def __init__(self):
        super(BasicModel, self).__init__()
        self.linear1 = nn.Linear(100, 25)
        self.linear2 = nn.Linear(25, 1)

    def forward(self, x, state):
        x = self.linear1(x)
        # A contrived but simple example of a forward method that uses state
        state[MOCK] = torch.sum(x)
        x = self.linear2(x)
        return torch.sigmoid(x)
```

We create a random training dataset and put it in a DataLoader.

In [4]:

```
from torch.utils.data import TensorDataset, DataLoader
n_sample = 100
X = torch.rand(n_sample, 100)
y = torch.randint(0, 2, [n_sample, 1]).float()
traingen = DataLoader(TensorDataset(X, y))
```

Let's say we would like to save the model every time we get a better training loss. Torchbearer's `Best` checkpoint callback is perfect for this job. We then run the model for 3 epochs.

In [5]:

```
import torch.optim as optim
import torch.nn.functional as F
from torchbearer import Trial
model = BasicModel()
# Create a checkpointer that tracks the training loss and saves model.pt whenever we get a better loss
checkpointer = torchbearer.callbacks.checkpointers.Best(filepath='model.pt', monitor='loss')
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
torchbearer_trial = Trial(model, optimizer=optimizer, criterion=F.binary_cross_entropy, metrics=['loss'],
                          callbacks=[checkpointer])
torchbearer_trial.with_train_generator(traingen)
_ = torchbearer_trial.run(epochs=3)
```

Given that we recreate the exact same Trial structure, we can easily resume our run from the last checkpoint. The following code block shows how this is done. Remember that the `epochs` parameter passed to `run` is cumulative: the following run trains for 3 further epochs, bringing the total to 6.

In [6]:

```
state_dict = torch.load('model.pt')
model = BasicModel()
trial_reloaded = Trial(model, optimizer=optimizer, criterion=F.binary_cross_entropy, metrics=['loss'],
                       callbacks=[checkpointer])
trial_reloaded.load_state_dict(state_dict)
trial_reloaded.with_train_generator(traingen)
_ = trial_reloaded.run(epochs=6)
```

Now, what happens if we try to load this checkpoint directly into a plain PyTorch model?

In [7]:

```
model = BasicModel()
try:
    model.load_state_dict(state_dict)
except AttributeError as e:
    print("\n")
    print(e)
```

This gives an error. The reason is that the `state_dict` contains Trial-related attributes that are unknown to a native PyTorch model. This is why our checkpointers have the `save_model_params_only` option. We try again with that option.
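The difference between the two layouts can be sketched with plain dictionaries (illustrative only; the key names below are made up and are not torchbearer's actual internals):

```python
# Illustrative sketch of the two checkpoint layouts (key names invented
# for the example, not torchbearer's real internals).
model_weights = {'linear1.weight': '...', 'linear2.weight': '...'}

# save_model_params_only=True: the checkpoint is just the weight dict,
# exactly what nn.Module.load_state_dict expects
params_only = dict(model_weights)

# default: the weights are nested alongside other trial state
full_trial = {'model': model_weights, 'optimizer': {'lr': 0.001}, 'callbacks': []}

print(set(params_only) == set(model_weights))  # True: keys match the model
print(set(full_trial) == set(model_weights))   # False: extra trial-level keys
```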

In [9]:

```
model = BasicModel()
checkpointer = torchbearer.callbacks.checkpointers.Best(filepath='model.pt', monitor='loss', save_model_params_only=True)
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
torchbearer_trial = Trial(model, optimizer=optimizer, criterion=F.binary_cross_entropy, metrics=['loss'],
                          callbacks=[checkpointer])
torchbearer_trial.with_train_generator(traingen)
torchbearer_trial.run(epochs=3)
# Try once again to load the module, forward another random sample for testing
state_dict = torch.load('model.pt')
model = BasicModel()
_ = model.load_state_dict(state_dict)
```

No errors this time, but we still have to test it. Here we create a test sample and run it through the model.

In [10]:

```
X_test = torch.rand(5, 100)
try:
    model(X_test)
except TypeError as e:
    print("\n")
    print(e)
```

Now we get a different error, stating that we should also pass `state` as an argument to the module's forward method. This should not be a surprise, as we defined `state` as a required argument in the forward method of `BasicModel`.
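The difference between the two signatures is plain Python, independent of torchbearer. Here is a minimal sketch with toy classes (no torch needed):

```python
# Toy classes contrasting the two forward signatures
class RequiredState:
    def forward(self, x, state):
        state['mock'] = x       # fails if no state dict is supplied
        return x


class KwargsState:
    def forward(self, x, **state):
        if state:               # state is an empty dict when nothing is passed
            state['mock'] = x
        return x


try:
    RequiredState().forward(3)  # TypeError: 'state' is required
except TypeError as e:
    print(e)

print(KwargsState().forward(3))  # works without state: prints 3
```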

We define the model with a better signature this time, so it gracefully handles the problem above.

In [11]:

```
class BetterSignatureModel(nn.Module):
    def __init__(self):
        super(BetterSignatureModel, self).__init__()
        self.linear1 = nn.Linear(100, 25)
        self.linear2 = nn.Linear(25, 1)

    def forward(self, x, **state):
        x = self.linear1(x)
        # Using kwargs instead of a positional state argument is safer:
        # state is simply an empty dict when no state is passed
        if state:
            state[MOCK] = torch.sum(x)
        x = self.linear2(x)
        return torch.sigmoid(x)
```

Finally, we wrap it up once again to test the new definition of the model.

In [12]:

```
model = BetterSignatureModel()
checkpointer = torchbearer.callbacks.checkpointers.Best(filepath='model.pt', monitor='loss', save_model_params_only=True)
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
torchbearer_trial = Trial(model, optimizer=optimizer, criterion=F.binary_cross_entropy, metrics=['loss'],
                          callbacks=[checkpointer])
torchbearer_trial.with_train_generator(traingen)
torchbearer_trial.run(epochs=3)
# This time, the forward function should work without the need for a state argument
state_dict = torch.load('model.pt')
model = BetterSignatureModel()
model.load_state_dict(state_dict)
X_test = torch.rand(5, 100)
model(X_test)
```
