This notebook shows some more advanced features of skorch. More examples will be added with time.
import subprocess

# Installation on Google Colab
try:
    import google.colab
    subprocess.run(['python', '-m', 'pip', 'install', 'skorch', 'torch'])
except ImportError:
    pass
import torch
from torch import nn
import torch.nn.functional as F
torch.manual_seed(0)
torch.cuda.manual_seed(0)
We load a toy classification task from sklearn.
import numpy as np
from sklearn.datasets import make_classification
np.random.seed(0)
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X, y = X.astype(np.float32), y.astype(np.int64)
X.shape, y.shape, y.mean()
((1000, 20), (1000,), 0.5)
pytorch classification module
We define a vanilla neural network with two hidden layers. The output layer should have 2 output units since there are two classes. In addition, it should have a softmax nonlinearity, because later, when calling predict_proba, the output from the forward call will be used.
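To make the role of the softmax concrete, here is a minimal pure-Python sketch. The helper name `softmax` and the sample values are illustrative assumptions, not part of the notebook. It shows that two raw output values are turned into non-negative numbers that sum to 1, which is what a probability output needs:

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    shift = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw outputs of the last linear layer for the two classes
probs = softmax([2.0, -1.0])
print(probs)
```

The larger raw score receives the larger probability, and the two values always add up to 1.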
from skorch import NeuralNetClassifier
class ClassifierModule(nn.Module):
    def __init__(
            self,
            num_units=10,
            nonlin=F.relu,
            dropout=0.5,
    ):
        super(ClassifierModule, self).__init__()
        self.num_units = num_units
        self.nonlin = nonlin
        self.dropout = dropout

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(dropout)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = F.relu(self.dense1(X))
        X = F.softmax(self.output(X), dim=-1)
        return X
Callbacks are a powerful and flexible way to customize the behavior of your neural network. They are all called at specific points during the model training, e.g. when training starts, or after each batch. Have a look at the skorch.callbacks module to see the callbacks that are already implemented.
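As a rough mental model of "called at specific points", the following pure-Python sketch records the order in which three hooks fire. The training loop `toy_fit` and the class `RecordingCallback` are hypothetical stand-ins written for this illustration; the real skorch loop has more hooks and does actual training:

```python
class RecordingCallback:
    """Toy callback that records the order in which its hooks are called."""

    def __init__(self):
        self.calls = []

    def on_train_begin(self, net, **kwargs):
        self.calls.append("on_train_begin")

    def on_epoch_end(self, net, **kwargs):
        self.calls.append("on_epoch_end")

    def on_train_end(self, net, **kwargs):
        self.calls.append("on_train_end")


def toy_fit(callbacks, max_epochs):
    """Hypothetical stand-in for a training loop; no real training happens."""
    net = object()  # placeholder for the net instance handed to each hook
    for cb in callbacks:
        cb.on_train_begin(net)
    for _ in range(max_epochs):
        # ... one epoch of training would run here ...
        for cb in callbacks:
            cb.on_epoch_end(net)
    for cb in callbacks:
        cb.on_train_end(net)


recorder = RecordingCallback()
toy_fit([recorder], max_epochs=2)
print(recorder.calls)
```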
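Although skorch comes with a handful of useful callbacks, you may find that you would like to write your own callbacks. Doing so is straightforward; just remember these rules:
- They should inherit from skorch.callbacks.Callback.
- They should implement at least one of the on_-methods provided by the parent class (e.g. on_batch_begin or on_epoch_end).
- As arguments, the on_-methods first get the NeuralNet instance, and, where appropriate, the local data (e.g. the data from the current batch). The method should also have **kwargs in the signature for potentially unused arguments.
- Optional: attributes that should be reset when the model is re-initialized should be set in the initialize method.
Here is an example of a callback that remembers at which epoch the validation accuracy reached a certain value. Then, when training is finished, it calls a mock Twitter API and tweets that epoch. We proceed as follows:
- We set the desired minimum accuracy during __init__.
- We set the critical epoch during initialize.
- After each epoch, if the minimum accuracy has not been reached yet, we check whether it has been reached now.
- When training finishes, we send a tweet saying whether the minimum accuracy was reached.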
from skorch.callbacks import Callback


def tweet(msg):
    print("~" * 60)
    print("*tweet*", msg, "#skorch #pytorch")
    print("~" * 60)


class AccuracyTweet(Callback):
    def __init__(self, min_accuracy):
        self.min_accuracy = min_accuracy

    def initialize(self):
        self.critical_epoch_ = -1

    # This runs after every epoch
    def on_epoch_end(self, net, **kwargs):
        if self.critical_epoch_ > -1:
            return
        # look at the validation accuracy of the last epoch
        if net.history[-1, 'valid_acc'] >= self.min_accuracy:
            self.critical_epoch_ = len(net.history)

    # This runs at the end of training
    def on_train_end(self, net, **kwargs):
        if self.critical_epoch_ < 0:
            msg = "Accuracy never reached {} :(".format(self.min_accuracy)
        else:
            msg = "Accuracy reached {} at epoch {}!!!".format(
                self.min_accuracy, self.critical_epoch_)
        tweet(msg)
Now we initialize a NeuralNetClassifier and pass our new callback in a list to the callbacks argument. After that, we train the model and see what happens.
net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=15,
    lr=0.02,
    warm_start=True,
    callbacks=[AccuracyTweet(min_accuracy=0.7)],
)
net.fit(X, y)
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.6954       0.6000        0.6844  0.1888
      2        0.6871       0.5950        0.6820  0.0179
      3        0.6826       0.6100        0.6793  0.0228
      4        0.6751       0.6100        0.6775  0.0206
      5        0.6773       0.6150        0.6754  0.0157
      6        0.6722       0.6150        0.6733  0.0162
      7        0.6665       0.6200        0.6707  0.0171
      8        0.6634       0.6200        0.6685  0.0175
      9        0.6662       0.6200        0.6659  0.0186
     10        0.6635       0.6500        0.6636  0.0169
     11        0.6605       0.6550        0.6616  0.0159
     12        0.6605       0.6600        0.6593  0.0166
     13        0.6616       0.6650        0.6568  0.0184
     14        0.6485       0.6750        0.6546  0.0206
     15        0.6464       0.6750        0.6518  0.0167
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*tweet* Accuracy never reached 0.7 :( #skorch #pytorch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=ClassifierModule(
    (dense0): Linear(in_features=20, out_features=10, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (dense1): Linear(in_features=10, out_features=10, bias=True)
    (output): Linear(in_features=10, out_features=2, bias=True)
  ),
)
Oh no, our model never reached a validation accuracy of 0.7. Let's train some more (this is possible because we set warm_start=True):
# warm_start continues training from the point where training stopped previously.
net.fit(X, y)
     16        0.6431       0.6800        0.6491  0.0180
     17        0.6406       0.6850        0.6460  0.0166
     18        0.6501       0.6900        0.6437  0.0161
     19        0.6450       0.6950        0.6410  0.0194
     20        0.6330       0.7000        0.6380  0.0266
     21        0.6306       0.7100        0.6352  0.0175
     22        0.6305       0.7100        0.6319  0.0155
     23        0.6329       0.7100        0.6295  0.0171
     24        0.6322       0.7150        0.6269  0.0173
     25        0.6188       0.7050        0.6241  0.0183
     26        0.6163       0.7000        0.6206  0.0183
     27        0.6133       0.7050        0.6176  0.0171
     28        0.6214       0.7050        0.6150  0.0190
     29        0.6099       0.7000        0.6122  0.0176
     30        0.6156       0.7000        0.6095  0.0181
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*tweet* Accuracy reached 0.7 at epoch 20!!! #skorch #pytorch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=ClassifierModule(
    (dense0): Linear(in_features=20, out_features=10, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (dense1): Linear(in_features=10, out_features=10, bias=True)
    (output): Linear(in_features=10, out_features=2, bias=True)
  ),
)
assert net.history[-1, 'valid_acc'] >= 0.7
Finally, the validation accuracy reached 0.7. Hooray!
Say you would like to use a learning rate schedule with your neural net, but you don't know what parameters are best for that schedule. Wouldn't it be nice if you could find those parameters with a grid search? With skorch, this is possible. Below, we show how to access the parameters of your callbacks.
To simplify the access to your callback parameters, it is best if you give your callback a name. This is achieved by passing the callbacks parameter a list of name, callback tuples, such as:
callbacks=[
    ('scheduler', LearningRateScheduler),
    ...
],
This way, you can access your callbacks using the double underscore semantics (as, for instance, in an sklearn Pipeline):
callbacks__scheduler__epoch=50,
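The idea behind the double underscore semantics can be sketched in a few lines of plain Python. This is only an illustration of the naming scheme, not how skorch or sklearn implement it; the helpers `split_param` and `set_nested` and the nested dict are hypothetical:

```python
def split_param(name):
    """Split 'a__b__c' into its path components."""
    return name.split("__")


def set_nested(tree, name, value):
    """Walk a nested dict along the path and set the final key."""
    *path, leaf = split_param(name)
    target = tree
    for key in path:
        target = target[key]
    target[leaf] = value


# Hypothetical parameter tree standing in for a net with one named callback
params = {"callbacks": {"scheduler": {"epoch": 10}}}
set_nested(params, "callbacks__scheduler__epoch", 50)
print(params)
```

Each `__` descends one level: first into the callbacks, then into the callback named `scheduler`, and finally to its `epoch` parameter.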
So if you would like to perform a grid search on, say, the number of units in the hidden layer and the learning rate schedule, it could look something like this:
param_grid = {
    'module__num_units': [50, 100, 150],
    'callbacks__scheduler__epoch': [10, 50, 100],
}
Note: If you would like to refresh your knowledge on grid search, look here, here, or in the Basic_Usage notebook.
Below, we show how accessing the callback parameters works with our AccuracyTweet callback:
net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=10,
    lr=0.1,
    warm_start=True,
    callbacks=[
        ('tweet', AccuracyTweet(min_accuracy=0.7)),
    ],
    callbacks__tweet__min_accuracy=0.6,
)
net.fit(X, y)
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.7003       0.5150        0.6880  0.0190
      2        0.6825       0.6250        0.6761  0.0160
      3        0.6632       0.6450        0.6665  0.0174
      4        0.6545       0.6600        0.6574  0.0253
      5        0.6397       0.6450        0.6459  0.0219
      6        0.6348       0.6750        0.6370  0.0209
      7        0.6239       0.6850        0.6276  0.0178
      8        0.6119       0.6950        0.6166  0.0217
      9        0.5940       0.7250        0.6113  0.0225
     10        0.5908       0.7250        0.6017  0.0205
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*tweet* Accuracy reached 0.6 at epoch 2!!! #skorch #pytorch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=ClassifierModule(
    (dense0): Linear(in_features=20, out_features=10, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (dense1): Linear(in_features=10, out_features=10, bias=True)
    (output): Linear(in_features=10, out_features=2, bias=True)
  ),
)
As you can see, by passing callbacks__tweet__min_accuracy=0.6, we changed that parameter. The same can be achieved by calling the set_params method with the corresponding arguments:
net.set_params(callbacks__tweet__min_accuracy=0.75)
<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=ClassifierModule(
    (dense0): Linear(in_features=20, out_features=10, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (dense1): Linear(in_features=10, out_features=10, bias=True)
    (output): Linear(in_features=10, out_features=2, bias=True)
  ),
)
net.fit(X, y)
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
     11        0.5809       0.7300        0.5908  0.0186
     12        0.5580       0.7000        0.5864  0.0184
     13        0.5604       0.7250        0.5752  0.0166
     14        0.5514       0.7200        0.5673  0.0159
     15        0.5444       0.7200        0.5599  0.0205
     16        0.5467       0.7300        0.5511  0.0240
     17        0.5246       0.7350        0.5460  0.0158
     18        0.5498       0.7200        0.5428  0.0556
     19        0.5197       0.7350        0.5407  0.0624
     20        0.5159       0.7350        0.5355  0.0804
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*tweet* Accuracy never reached 0.75 :( #skorch #pytorch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=ClassifierModule(
    (dense0): Linear(in_features=20, out_features=10, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (dense1): Linear(in_features=10, out_features=10, bias=True)
    (output): Linear(in_features=10, out_features=2, bias=True)
  ),
)
Datasets
We encourage you not to pass Datasets to net.fit but to let skorch handle Datasets internally. Nonetheless, there are situations where passing Datasets to net.fit is hard to avoid (e.g. if you want to load the data lazily during training). This is supported by skorch but may have some unwanted side effects relating to sklearn. For instance, Datasets cannot be split into train and validation sets in a stratified fashion without explicit knowledge of the classification targets.
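Why a stratified split needs the targets can be seen from a small pure-Python sketch. The function `stratified_split` below is a hypothetical illustration, not the splitter skorch or sklearn use: it has to group the sample indices by class label before it can draw the same fraction from every class, so without y there is nothing to group by.

```python
import random
from collections import defaultdict


def stratified_split(y, valid_fraction=0.2, seed=0):
    """Return (train_idx, valid_idx) with the same class ratio in both parts."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, label in enumerate(y):  # this step requires the targets
        by_class[label].append(i)
    train_idx, valid_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_valid = int(round(len(indices) * valid_fraction))
        valid_idx.extend(indices[:n_valid])
        train_idx.extend(indices[n_valid:])
    return sorted(train_idx), sorted(valid_idx)


y_toy = [0] * 80 + [1] * 20  # imbalanced toy targets
train_idx, valid_idx = stratified_split(y_toy)
valid_labels = [y_toy[i] for i in valid_idx]
print(len(valid_idx), sum(valid_labels))
```

With 80 samples of class 0 and 20 of class 1, the validation part keeps the 4:1 ratio of the full data.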
Below we show what happens when you try to fit with a Dataset and the stratified split fails:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        assert len(X) == len(y)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]


X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X, y = X.astype(np.float32), y.astype(np.int64)
dataset = MyDataset(X, y)
net = NeuralNetClassifier(ClassifierModule)
try:
    net.fit(dataset, y=None)
except ValueError as e:
    print("Error:", e)
Error: Stratified CV requires explicitly passing a suitable y.
net.train_split.stratified
True
As you can see, the stratified split fails since y is not known. There are two solutions to this:
- turn off stratified splitting (net.train_split.stratified=False)
- pass y explicitly (if possible), even if it is implicitly contained in the Dataset
The second solution is shown below:
net.fit(dataset, y=y)
Re-initializing module.
Re-initializing criterion.
Re-initializing optimizer.
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.6994       0.4600        0.7100  0.0160
      2        0.6960       0.4750        0.7072  0.0302
      3        0.6983       0.4900        0.7046  0.0148
      4        0.6945       0.5050        0.7023  0.0139
      5        0.6859       0.5050        0.7000  0.0146
      6        0.6834       0.5300        0.6978  0.0148
      7        0.6799       0.5450        0.6960  0.0163
      8        0.6734       0.5500        0.6942  0.0161
      9        0.6743       0.5450        0.6923  0.0170
     10        0.6666       0.5500        0.6906  0.0166
<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=ClassifierModule(
    (dense0): Linear(in_features=20, out_features=10, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (dense1): Linear(in_features=10, out_features=10, bias=True)
    (output): Linear(in_features=10, out_features=2, bias=True)
  ),
)
skorch has built-in support for dictionaries as data containers. Here we show a somewhat contrived example of how to use dicts, but it should get the point across. First we create data and put it into a dictionary X_dict with two keys X0 and X1:
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X, y = X.astype(np.float32), y.astype(np.int64)
X0, X1 = X[:, :10], X[:, 10:]
X_dict = {'X0': X0, 'X1': X1}
When skorch passes the dict to the pytorch module, it will pass the data as keyword arguments to the forward call. That means that we should accept the two keys X0 and X1 in the forward method, as shown below:
class ClassifierWithDict(nn.Module):
    def __init__(
            self,
            num_units0=50,
            num_units1=50,
            nonlin=F.relu,
            dropout=0.5,
    ):
        super(ClassifierWithDict, self).__init__()
        self.num_units0 = num_units0
        self.num_units1 = num_units1
        self.nonlin = nonlin
        self.dropout = dropout

        self.dense0 = nn.Linear(10, num_units0)
        self.dense1 = nn.Linear(10, num_units1)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(num_units0 + num_units1, 2)

    # NOTE: We accept X0 and X1, the keys from the dict, as arguments
    def forward(self, X0, X1, **kwargs):
        X0 = self.nonlin(self.dense0(X0))
        X0 = self.dropout(X0)

        X1 = self.nonlin(self.dense1(X1))
        X1 = self.dropout(X1)

        X = torch.cat((X0, X1), dim=1)
        X = F.relu(X)
        X = F.softmax(self.output(X), dim=-1)
        return X
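The keyword-argument mechanism that the forward method above relies on is ordinary Python dict unpacking. The following sketch uses a toy function and plain lists (both hypothetical, chosen for illustration) to show how each dict key is matched to the parameter with the same name:

```python
def forward(X0, X1, **kwargs):
    """Toy forward function; it only reports the size of what it received."""
    return len(X0), len(X1)


# Hypothetical stand-in for X_dict, with plain lists instead of arrays
toy_dict = {"X0": [1.0, 2.0, 3.0], "X1": [4.0, 5.0, 6.0]}

# Unpacking the dict maps each key to the parameter of the same name
result = forward(**toy_dict)
print(result)
```

A misspelled key would not match any named parameter and would end up in `**kwargs` instead, leaving the required argument missing.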
As long as we keep this in mind, we are good to go.
net = NeuralNetClassifier(ClassifierWithDict, verbose=0)
net.fit(X_dict, y)
<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=ClassifierWithDict(
    (dense0): Linear(in_features=10, out_features=50, bias=True)
    (dense1): Linear(in_features=10, out_features=50, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (output): Linear(in_features=100, out_features=2, bias=True)
  ),
)
Pipeline and GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.model_selection import GridSearchCV
sklearn makes the assumption that incoming data should be numpy/sparse arrays or something similar. This clashes with the use of dictionaries. Unfortunately, it is sometimes impossible to work around that for now (for instance when using skorch with BaggingClassifier). Other times, there are possibilities.
When we have a preprocessing pipeline that involves FunctionTransformer, we have to pass the parameter validate=False (which is the default value now) so that sklearn allows the dictionary to pass through. Everything else works:
pipe = Pipeline([
    ('do-nothing', FunctionTransformer(validate=False)),
    ('net', net),
])
pipe.fit(X_dict, y)
Pipeline(steps=[('do-nothing', FunctionTransformer()), ('net', <class 'skorch.classifier.NeuralNetClassifier'>[initialized]( module_=ClassifierWithDict( (dense0): Linear(in_features=10, out_features=50, bias=True) (dense1): Linear(in_features=10, out_features=50, bias=True) (dropout): Dropout(p=0.5, inplace=False) (output): Linear(in_features=100, out_features=2, bias=True) ), ))])
When trying a grid or randomized search, it is not that easy to pass a dict. If we try, we will get an error:
param_grid = {
    'net__module__num_units0': [10, 25, 50],
    'net__module__num_units1': [10, 25, 50],
    'net__lr': [0.01, 0.1],
}
grid_search = GridSearchCV(pipe, param_grid, scoring='accuracy', verbose=1, cv=3)
try:
    grid_search.fit(X_dict, y)
except Exception as e:
    print(e)
Found input variables with inconsistent numbers of samples: [2, 1000]
The error above occurs because sklearn gets the length of the input data, which is 2 for the dict, and believes that is inconsistent with the length of the target (1000).
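The two numbers in the error message can be reproduced without sklearn. In this small sketch, the toy containers are placeholders with the same sizes as the data used in this notebook; it shows that `len` of a dict counts its keys, not the samples stored in its values:

```python
# Placeholder data: 1000 samples, split into two blocks of 10 features each
toy_X0 = [[0.0] * 10 for _ in range(1000)]
toy_X1 = [[0.0] * 10 for _ in range(1000)]
toy_y = [0] * 1000

toy_dict = {"X0": toy_X0, "X1": toy_X1}

# len() of the dict is the number of keys, which is what sklearn sees
lengths = [len(toy_dict), len(toy_y)]
print(lengths)
```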
To get around that, skorch provides a helper class called SliceDict. It allows us to wrap our dictionaries so that they also behave like a numpy array:
from skorch.helper import SliceDict
X_slice_dict = SliceDict(X0=X0, X1=X1) # X_slice_dict = SliceDict(**X_dict) would also work
The SliceDict shows the correct length, shape, and is sliceable across values:
print("Length of dict: {}, length of SliceDict: {}".format(len(X_dict), len(X_slice_dict)))
print("Shape of SliceDict: {}".format(X_slice_dict.shape))
Length of dict: 2, length of SliceDict: 1000
Shape of SliceDict: (1000,)
print("Slicing the SliceDict slices across values: {}".format(X_slice_dict[:2]))
Slicing the SliceDict slices across values: SliceDict(**{'X0': array([[-0.9658346 , -2.1890705 , 0.16985609, 0.8138456 , -3.375209 , -2.1430597 , -0.39585084, 2.9419577 , -2.1910605 , 1.2443967 ], [-0.454767 , 4.339768 , -0.48572844, -4.88433 , -2.8836503 , 2.6097205 , -1.952876 , -0.09192174, 0.07970932, -0.08938338]], dtype=float32), 'X1': array([[ 0.04351204, -0.5150961 , -0.86073655, -1.1097169 , 0.31839254, -0.8231973 , -1.056304 , -0.89645284, 0.3759244 , -1.0849651 ], [-0.60726726, -1.0674309 , 0.48804346, -0.50230557, 0.55743027, 1.01592 , -1.9953582 , 2.9030426 , -0.9739298 , 2.1753323 ]], dtype=float32)})
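To get an intuition for this behavior, here is a minimal sketch of a dict whose length and slicing follow its values. It is explicitly not skorch's SliceDict implementation; the class `ToySliceDict` is a hypothetical illustration that assumes all values are sequences of equal length and that it is indexed only with string keys or slices:

```python
class ToySliceDict(dict):
    """Illustrative dict whose length and slicing follow its values."""

    def __len__(self):
        # Report the number of samples, not the number of keys
        first = next(iter(self.values()), [])
        return len(first)

    def __getitem__(self, key):
        if isinstance(key, str):
            # Normal dict access by key name
            return super().__getitem__(key)
        # For a slice, apply it to every value and wrap the result again
        return ToySliceDict({k: v[key] for k, v in self.items()})


toy = ToySliceDict(X0=list(range(10)), X1=list(range(10, 20)))
head = toy[:2]
print(len(toy), head["X0"], head["X1"])
```

Because the length now equals the number of samples and slicing returns the same rows from every value, a cross-validation splitter can treat such an object much like an array.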
With this, we can call GridSearchCV just as expected:
grid_search.fit(X_slice_dict, y)
Fitting 3 folds for each of 18 candidates, totalling 54 fits
GridSearchCV(cv=3, estimator=Pipeline(steps=[('do-nothing', FunctionTransformer()), ('net', <class 'skorch.classifier.NeuralNetClassifier'>[initialized]( module_=ClassifierWithDict( (dense0): Linear(in_features=10, out_features=50, bias=True) (dense1): Linear(in_features=10, out_features=50, bias=True) (dropout): Dropout(p=0.5, inplace=False) (output): Linear(in_features=100, out_features=2, bias=True) ), ))]), param_grid={'net__lr': [0.01, 0.1], 'net__module__num_units0': [10, 25, 50], 'net__module__num_units1': [10, 25, 50]}, scoring='accuracy', verbose=1)
grid_search.best_score_, grid_search.best_params_
(0.7429825034615454, {'net__lr': 0.1, 'net__module__num_units0': 50, 'net__module__num_units1': 50})
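The message "Fitting 3 folds for each of 18 candidates, totalling 54 fits" reported above can be checked by hand. The sketch below enumerates the grid with `itertools.product`; it reuses the parameter values from the grid search above and only counts the combinations, without training anything:

```python
from itertools import product

param_grid = {
    'net__module__num_units0': [10, 25, 50],
    'net__module__num_units1': [10, 25, 50],
    'net__lr': [0.01, 0.1],
}
cv = 3

# Every combination of one value per parameter is one candidate
candidates = [
    dict(zip(param_grid, values))
    for values in product(*param_grid.values())
]
n_fits = len(candidates) * cv
print(len(candidates), n_fits)
```

The grid therefore grows multiplicatively (3 x 3 x 2 candidates), which is worth keeping in mind before adding further parameters.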
Multiple return values from forward
Often, we want our Module.forward method to return more than just one value. There can be several reasons for this. Maybe the criterion requires not one but several outputs. Or perhaps we want to inspect intermediate values to learn more about our model (say, inspecting attention in a sequence-to-sequence model). Fortunately, skorch makes it easy to achieve this. In the following, we demonstrate how to handle multiple outputs from the Module.
To demonstrate this, we implement a very simple autoencoder. It consists of an encoder that reduces our input of 20 units to 5 units using two linear layers, and a decoder that tries to reconstruct the original input, again using two linear layers.
from skorch import NeuralNetRegressor
class Encoder(nn.Module):
    def __init__(self, num_units=5):
        super().__init__()
        self.num_units = num_units

        self.encode = nn.Sequential(
            nn.Linear(20, 10),
            nn.ReLU(),
            nn.Linear(10, self.num_units),
            nn.ReLU(),
        )

    def forward(self, X):
        encoded = self.encode(X)
        return encoded


class Decoder(nn.Module):
    def __init__(self, num_units):
        super().__init__()
        self.num_units = num_units

        self.decode = nn.Sequential(
            nn.Linear(self.num_units, 10),
            nn.ReLU(),
            nn.Linear(10, 20),
        )

    def forward(self, X):
        decoded = self.decode(X)
        return decoded
The autoencoder module below actually returns a tuple of two values, the decoded input and the encoded input. This way, we can not only use the decoded input to calculate the normal loss but also have access to the encoded state.
class AutoEncoder(nn.Module):
    def __init__(self, num_units):
        super().__init__()
        self.num_units = num_units

        self.encoder = Encoder(num_units=self.num_units)
        self.decoder = Decoder(num_units=self.num_units)

    def forward(self, X):
        encoded = self.encoder(X)
        decoded = self.decoder(encoded)
        return decoded, encoded  # <- return a tuple of two values
Since the module's forward method returns two values, we have to adjust our objective to do the right thing with those values. If we don't do this, the criterion wouldn't know what to do with the two values and would raise an error.
One strategy would be to only use the decoded state for the loss and discard the encoded state. For this demonstration, we have a different plan: We would like the encoded state to be sparse. Therefore, we add an L1 loss of the encoded state to the reconstruction loss. This way, the net will try to reconstruct the input as accurately as possible while keeping the encoded state as sparse as possible.
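The arithmetic of such a combined objective can be worked through on tiny numbers. The sketch below uses plain Python lists; the helper names, the toy values, and the weight of 1e-3 on the L1 term are illustrative assumptions, and the mean squared error stands in for the reconstruction loss:

```python
def mse(pred, target):
    """Mean squared error over all elements of two nested lists."""
    n = sum(len(row) for row in pred)
    total = sum(
        (p - t) ** 2
        for pred_row, target_row in zip(pred, target)
        for p, t in zip(pred_row, target_row)
    )
    return total / n


def l1_penalty(encoded, weight=1e-3):
    """Weighted sum of absolute values; it grows with every non-zero entry."""
    return weight * sum(abs(v) for row in encoded for v in row)


target = [[1.0, 2.0]]
decoded = [[1.0, 1.0]]       # one of the two values is off by 1
encoded = [[0.0, 3.0, 0.0]]  # sparse code with a single active unit

loss = mse(decoded, target) + l1_penalty(encoded)
print(loss)
```

Here the reconstruction term contributes 0.5 and the penalty 0.003. Zero entries in the encoded state cost nothing, so the penalty pushes the optimizer toward sparse codes.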
To implement this, the right method to override is called get_loss, which is where skorch computes and returns the loss. It gets the prediction (our tuple) and the target as input, as well as other arguments and keywords that we pass through. We create a subclass of NeuralNetRegressor that overrides said method and implements our idea for the loss.
class AutoEncoderNet(NeuralNetRegressor):
    def get_loss(self, y_pred, y_true, *args, **kwargs):
        decoded, encoded = y_pred  # <- unpack the tuple that was returned by `forward`
        loss_reconstruction = super().get_loss(decoded, y_true, *args, **kwargs)
        loss_l1 = 1e-3 * torch.abs(encoded).sum()
        return loss_reconstruction + loss_l1
Note: Alternatively, we could have used an unaltered NeuralNetRegressor but implemented a custom criterion that is responsible for unpacking the tuple and computing the loss.
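A custom criterion along those lines might look like the following sketch. The class name `SparseReconstructionLoss` and its `l1_weight` parameter are hypothetical, and the sketch assumes that the criterion receives the module's output tuple unchanged as its first argument, which should be verified against the skorch version in use:

```python
import torch
from torch import nn


class SparseReconstructionLoss(nn.Module):
    """Hypothetical criterion: MSE on the decoded output plus an L1 term."""

    def __init__(self, l1_weight=1e-3):
        super().__init__()
        self.l1_weight = l1_weight
        self.mse = nn.MSELoss()

    def forward(self, y_pred, y_true):
        decoded, encoded = y_pred  # unpack the tuple returned by the module
        loss_reconstruction = self.mse(decoded, y_true)
        loss_l1 = self.l1_weight * torch.abs(encoded).sum()
        return loss_reconstruction + loss_l1


# Tiny hand-made tensors to exercise the criterion on its own
criterion = SparseReconstructionLoss()
decoded = torch.tensor([[1.0, 1.0]])
encoded = torch.tensor([[0.0, 3.0, 0.0]])
target = torch.tensor([[1.0, 2.0]])
loss = criterion((decoded, encoded), target)
print(loss.item())
```

Under that assumption, the class would be passed through the `criterion` argument of an unaltered NeuralNetRegressor, and no subclass of the net would be needed.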
Now that everything is ready, we train the model as usual. We initialize our net subclass with the AutoEncoder module and call the fit method with X both as input and as target (since we want to reconstruct the original data):
net = AutoEncoderNet(
    AutoEncoder,
    module__num_units=5,
    lr=0.3,
)
net.fit(X, X)
  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1        3.8021        3.7869  0.0312
      2        3.6940        3.7218  0.0244
      3        3.6441        3.6828  0.0186
      4        3.6145        3.6578  0.0189
      5        3.5955        3.6407  0.0195
      6        3.5824        3.6276  0.0197
      7        3.5714        3.6146  0.0194
      8        3.5571        3.5906  0.0192
      9        3.5160        3.4825  0.0285
     10        3.3439        3.2388  0.0176
<class '__main__.AutoEncoderNet'>[initialized]( module_=AutoEncoder( (encoder): Encoder( (encode): Sequential( (0): Linear(in_features=20, out_features=10, bias=True) (1): ReLU() (2): Linear(in_features=10, out_features=5, bias=True) (3): ReLU() ) ) (decoder): Decoder( (decode): Sequential( (0): Linear(in_features=5, out_features=10, bias=True) (1): ReLU() (2): Linear(in_features=10, out_features=20, bias=True) ) ) ), )
Voilà, the model was trained using our custom loss function that makes use of both predicted values.
Sometimes, we may wish to inspect all the values returned by the forward method of the module. There are several ways to achieve this. In theory, we can always access the module directly by using the net.module_ attribute. However, this is unwieldy, since it completely bypasses the prediction loop, which takes care of important steps like casting numpy arrays to pytorch tensors and batching.
Also, we cannot use the predict method on the net. This method will only return the first output from the forward method, in this case the decoded state. The reason for this is that predict is part of the sklearn API, which requires there to be only one output. This is shown below:
y_pred = net.predict(X)
y_pred.shape # only the decoded state is returned
(1000, 20)
However, the net itself provides two methods to retrieve all outputs. The first one is the net.forward method, which retrieves all the predicted batches from the Module.forward and concatenates them. Use this to retrieve the complete decoded and encoded state:
decoded_pred, encoded_pred = net.forward(X)
decoded_pred.shape, encoded_pred.shape
(torch.Size([1000, 20]), torch.Size([1000, 5]))
The other method is called net.forward_iter. It is similar to net.forward but instead of collecting all the batches, this method is lazy and only yields one batch at a time. This can be especially useful if the output doesn't fit into memory:
for decoded_pred, encoded_pred in net.forward_iter(X):
    # do something with each batch
    break
decoded_pred.shape, encoded_pred.shape
(torch.Size([128, 20]), torch.Size([128, 5]))
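The shapes above show that a single batch holds 128 samples. The lazy, batch-by-batch behavior can be imitated with a plain Python generator. The function `iter_batches` is a hypothetical illustration rather than skorch code, and the batch size of 128 is assumed from the output above:

```python
def iter_batches(n_samples, batch_size=128):
    """Yield (start, stop) index pairs one batch at a time."""
    for start in range(0, n_samples, batch_size):
        yield start, min(start + batch_size, n_samples)


batches = iter_batches(1000)
first = next(batches)  # at this point only the first batch has been produced
sizes = [first[1] - first[0]] + [stop - start for start, stop in batches]
print(sizes)
```

For 1000 samples this gives seven full batches and a final, smaller one. Only one batch needs to be held in memory at any time.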
Finally, let's make sure that our initial goal of having a sparse encoded state was met. We check what fraction of the activations in encoded_pred (the batch retrieved above) is close to zero:
torch.isclose(encoded_pred, torch.zeros_like(encoded_pred)).float().mean()
tensor(0.8828)
As we had hoped, the encoded state is quite sparse, with the majority of outputs being 0.
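The kind of check used above can be reproduced on a small hand-made example. This sketch uses only the standard library; the helper `fraction_near_zero`, the tolerance, and the toy values are illustrative assumptions and are not the encoded state of the trained model:

```python
import math


def fraction_near_zero(rows, abs_tol=1e-8):
    """Fraction of entries whose absolute value is within abs_tol of zero."""
    values = [v for row in rows for v in row]
    near_zero = [math.isclose(v, 0.0, abs_tol=abs_tol) for v in values]
    return sum(near_zero) / len(values)


# Two toy encoded samples with four units each; six of eight entries are zero
toy_encoded = [[0.0, 0.0, 1.5, 0.0], [0.0, 2.0, 0.0, 0.0]]
frac = fraction_near_zero(toy_encoded)
print(frac)
```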