Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.


In [1]:
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Author: Sebastian Raschka

Python implementation: CPython
Python version       : 3.8.12
IPython version      : 8.0.1

torch: 1.10.1

AlexNet CIFAR-10 Classifier

Network Architecture

References

Imports

In [2]:
import os
import time
import random

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset

from torchvision import datasets
from torchvision import transforms

import matplotlib.pyplot as plt
from PIL import Image


if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

Model Settings

Setting a random seed

I recommend using a function like the following one prior to using dataset loaders and initializing a model if you want to ensure the data is shuffled in the same manner if you rerun this notebook and the model gets the same initial random weights:

In [3]:
def set_all_seeds(seed):
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

Setting cuDNN and PyTorch algorithmic behavior to deterministic

Similar to the set_all_seeds function above, I recommend setting the behavior of PyTorch and cuDNN to deterministic (this is particulary relevant when using GPUs). We can also define a function for that:

In [4]:
def set_deterministic():
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    torch.set_deterministic(True)
In [5]:
##########################
### SETTINGS
##########################

# Hyperparameters
RANDOM_SEED = 1
LEARNING_RATE = 0.0001
BATCH_SIZE = 256
NUM_EPOCHS = 40

# Architecture
NUM_CLASSES = 10

# Other
DEVICE = "cuda:0"

set_all_seeds(RANDOM_SEED)

# Deterministic behavior not yet supported by AdaptiveAvgPool2d
#set_deterministic()

Import utility functions

In [6]:
import sys

sys.path.insert(0, "..") # to include ../helper_evaluate.py etc.

from helper_evaluate import compute_accuracy
from helper_data import get_dataloaders_cifar10
from helper_train import train_classifier_simple_v1

Dataset

In [7]:
### Set random seed ###
set_all_seeds(RANDOM_SEED)

##########################
### Dataset
##########################

train_transforms = transforms.Compose([transforms.Resize((70, 70)),
                                       transforms.RandomCrop((64, 64)),
                                       transforms.ToTensor()])

test_transforms = transforms.Compose([transforms.Resize((70, 70)),
                                      transforms.CenterCrop((64, 64)),
                                      transforms.ToTensor()])


train_loader, valid_loader, test_loader = get_dataloaders_cifar10(
    batch_size=BATCH_SIZE, 
    num_workers=2, 
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    validation_fraction=0.1)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz
Extracting data/cifar-10-python.tar.gz to data
In [8]:
# Checking the dataset
print('Training Set:\n')
for images, labels in train_loader:  
    print('Image batch dimensions:', images.size())
    print('Image label dimensions:', labels.size())
    print(labels[:10])
    break
    
# Checking the dataset
print('\nValidation Set:')
for images, labels in valid_loader:  
    print('Image batch dimensions:', images.size())
    print('Image label dimensions:', labels.size())
    print(labels[:10])
    break

# Checking the dataset
print('\nTesting Set:')
for images, labels in train_loader:  
    print('Image batch dimensions:', images.size())
    print('Image label dimensions:', labels.size())
    print(labels[:10])
    break
Training Set:

Image batch dimensions: torch.Size([256, 3, 64, 64])
Image label dimensions: torch.Size([256])
tensor([0, 2, 3, 5, 4, 8, 9, 6, 9, 7])

Validation Set:
Image batch dimensions: torch.Size([256, 3, 64, 64])
Image label dimensions: torch.Size([256])
tensor([6, 9, 3, 5, 7, 3, 4, 1, 8, 0])

Testing Set:
Image batch dimensions: torch.Size([256, 3, 64, 64])
Image label dimensions: torch.Size([256])
tensor([2, 6, 3, 1, 1, 1, 1, 2, 4, 8])

Model

In [9]:
##########################
### MODEL
##########################

class AlexNet(nn.Module):

    def __init__(self, num_classes):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        logits = self.classifier(x)
        probas = F.softmax(logits, dim=1)
        return logits
In [10]:
torch.manual_seed(RANDOM_SEED)

model = AlexNet(NUM_CLASSES)
model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  

Training

In [11]:
log_dict = train_classifier_simple_v1(num_epochs=NUM_EPOCHS, model=model, 
                                      optimizer=optimizer, device=DEVICE, 
                                      train_loader=train_loader, valid_loader=valid_loader, 
                                      logging_interval=50)
Epoch: 001/040 | Batch 0000/0175 | Loss: 2.3033
Epoch: 001/040 | Batch 0050/0175 | Loss: 2.0240
Epoch: 001/040 | Batch 0100/0175 | Loss: 1.9445
Epoch: 001/040 | Batch 0150/0175 | Loss: 1.8135
***Epoch: 001/040 | Train. Acc.: 33.674% | Loss: 1.703
***Epoch: 001/040 | Valid. Acc.: 34.880% | Loss: 1.670
Time elapsed: 1.05 min
Epoch: 002/040 | Batch 0000/0175 | Loss: 1.7606
Epoch: 002/040 | Batch 0050/0175 | Loss: 1.5473
Epoch: 002/040 | Batch 0100/0175 | Loss: 1.5496
Epoch: 002/040 | Batch 0150/0175 | Loss: 1.5093
***Epoch: 002/040 | Train. Acc.: 42.819% | Loss: 1.505
***Epoch: 002/040 | Valid. Acc.: 43.840% | Loss: 1.491
Time elapsed: 2.09 min
Epoch: 003/040 | Batch 0000/0175 | Loss: 1.5411
Epoch: 003/040 | Batch 0050/0175 | Loss: 1.5485
Epoch: 003/040 | Batch 0100/0175 | Loss: 1.3723
Epoch: 003/040 | Batch 0150/0175 | Loss: 1.3084
***Epoch: 003/040 | Train. Acc.: 49.712% | Loss: 1.336
***Epoch: 003/040 | Valid. Acc.: 50.300% | Loss: 1.327
Time elapsed: 3.12 min
Epoch: 004/040 | Batch 0000/0175 | Loss: 1.4301
Epoch: 004/040 | Batch 0050/0175 | Loss: 1.4117
Epoch: 004/040 | Batch 0100/0175 | Loss: 1.2894
Epoch: 004/040 | Batch 0150/0175 | Loss: 1.1508
***Epoch: 004/040 | Train. Acc.: 54.138% | Loss: 1.231
***Epoch: 004/040 | Valid. Acc.: 54.300% | Loss: 1.226
Time elapsed: 4.16 min
Epoch: 005/040 | Batch 0000/0175 | Loss: 1.1781
Epoch: 005/040 | Batch 0050/0175 | Loss: 1.2942
Epoch: 005/040 | Batch 0100/0175 | Loss: 1.3343
Epoch: 005/040 | Batch 0150/0175 | Loss: 1.1216
***Epoch: 005/040 | Train. Acc.: 58.536% | Loss: 1.139
***Epoch: 005/040 | Valid. Acc.: 58.220% | Loss: 1.152
Time elapsed: 5.23 min
Epoch: 006/040 | Batch 0000/0175 | Loss: 1.1030
Epoch: 006/040 | Batch 0050/0175 | Loss: 1.1732
Epoch: 006/040 | Batch 0100/0175 | Loss: 1.1508
Epoch: 006/040 | Batch 0150/0175 | Loss: 1.0059
***Epoch: 006/040 | Train. Acc.: 58.882% | Loss: 1.132
***Epoch: 006/040 | Valid. Acc.: 58.600% | Loss: 1.158
Time elapsed: 6.28 min
Epoch: 007/040 | Batch 0000/0175 | Loss: 1.0091
Epoch: 007/040 | Batch 0050/0175 | Loss: 1.2888
Epoch: 007/040 | Batch 0100/0175 | Loss: 1.0148
Epoch: 007/040 | Batch 0150/0175 | Loss: 1.0491
***Epoch: 007/040 | Train. Acc.: 65.203% | Loss: 0.966
***Epoch: 007/040 | Valid. Acc.: 63.880% | Loss: 1.007
Time elapsed: 7.31 min
Epoch: 008/040 | Batch 0000/0175 | Loss: 0.8920
Epoch: 008/040 | Batch 0050/0175 | Loss: 0.9769
Epoch: 008/040 | Batch 0100/0175 | Loss: 1.0159
Epoch: 008/040 | Batch 0150/0175 | Loss: 1.0733
***Epoch: 008/040 | Train. Acc.: 67.181% | Loss: 0.920
***Epoch: 008/040 | Valid. Acc.: 65.020% | Loss: 0.974
Time elapsed: 8.35 min
Epoch: 009/040 | Batch 0000/0175 | Loss: 0.9276
Epoch: 009/040 | Batch 0050/0175 | Loss: 0.8630
Epoch: 009/040 | Batch 0100/0175 | Loss: 1.1130
Epoch: 009/040 | Batch 0150/0175 | Loss: 0.9105
***Epoch: 009/040 | Train. Acc.: 66.795% | Loss: 0.920
***Epoch: 009/040 | Valid. Acc.: 64.980% | Loss: 0.984
Time elapsed: 9.38 min
Epoch: 010/040 | Batch 0000/0175 | Loss: 0.8506
Epoch: 010/040 | Batch 0050/0175 | Loss: 0.7531
Epoch: 010/040 | Batch 0100/0175 | Loss: 0.9312
Epoch: 010/040 | Batch 0150/0175 | Loss: 0.9103
***Epoch: 010/040 | Train. Acc.: 70.491% | Loss: 0.832
***Epoch: 010/040 | Valid. Acc.: 67.560% | Loss: 0.934
Time elapsed: 10.42 min
Epoch: 011/040 | Batch 0000/0175 | Loss: 0.8196
Epoch: 011/040 | Batch 0050/0175 | Loss: 0.7955
Epoch: 011/040 | Batch 0100/0175 | Loss: 0.9367
Epoch: 011/040 | Batch 0150/0175 | Loss: 0.7501
***Epoch: 011/040 | Train. Acc.: 70.826% | Loss: 0.819
***Epoch: 011/040 | Valid. Acc.: 66.220% | Loss: 0.950
Time elapsed: 11.50 min
Epoch: 012/040 | Batch 0000/0175 | Loss: 0.7863
Epoch: 012/040 | Batch 0050/0175 | Loss: 0.8496
Epoch: 012/040 | Batch 0100/0175 | Loss: 0.7997
Epoch: 012/040 | Batch 0150/0175 | Loss: 0.9733
***Epoch: 012/040 | Train. Acc.: 73.600% | Loss: 0.757
***Epoch: 012/040 | Valid. Acc.: 68.640% | Loss: 0.901
Time elapsed: 12.72 min
Epoch: 013/040 | Batch 0000/0175 | Loss: 0.8286
Epoch: 013/040 | Batch 0050/0175 | Loss: 0.8397
Epoch: 013/040 | Batch 0100/0175 | Loss: 0.7478
Epoch: 013/040 | Batch 0150/0175 | Loss: 0.8451
***Epoch: 013/040 | Train. Acc.: 76.750% | Loss: 0.672
***Epoch: 013/040 | Valid. Acc.: 70.260% | Loss: 0.848
Time elapsed: 13.96 min
Epoch: 014/040 | Batch 0000/0175 | Loss: 0.6818
Epoch: 014/040 | Batch 0050/0175 | Loss: 0.7883
Epoch: 014/040 | Batch 0100/0175 | Loss: 0.7845
Epoch: 014/040 | Batch 0150/0175 | Loss: 0.6714
***Epoch: 014/040 | Train. Acc.: 76.462% | Loss: 0.669
***Epoch: 014/040 | Valid. Acc.: 69.560% | Loss: 0.876
Time elapsed: 15.17 min
Epoch: 015/040 | Batch 0000/0175 | Loss: 0.7720
Epoch: 015/040 | Batch 0050/0175 | Loss: 0.7569
Epoch: 015/040 | Batch 0100/0175 | Loss: 0.6428
Epoch: 015/040 | Batch 0150/0175 | Loss: 0.7415
***Epoch: 015/040 | Train. Acc.: 78.196% | Loss: 0.622
***Epoch: 015/040 | Valid. Acc.: 70.460% | Loss: 0.852
Time elapsed: 16.39 min
Epoch: 016/040 | Batch 0000/0175 | Loss: 0.6150
Epoch: 016/040 | Batch 0050/0175 | Loss: 0.7300
Epoch: 016/040 | Batch 0100/0175 | Loss: 0.4870
Epoch: 016/040 | Batch 0150/0175 | Loss: 0.6177
***Epoch: 016/040 | Train. Acc.: 80.033% | Loss: 0.571
***Epoch: 016/040 | Valid. Acc.: 71.500% | Loss: 0.832
Time elapsed: 17.62 min
Epoch: 017/040 | Batch 0000/0175 | Loss: 0.6556
Epoch: 017/040 | Batch 0050/0175 | Loss: 0.6564
Epoch: 017/040 | Batch 0100/0175 | Loss: 0.5505
Epoch: 017/040 | Batch 0150/0175 | Loss: 0.6272
***Epoch: 017/040 | Train. Acc.: 81.415% | Loss: 0.532
***Epoch: 017/040 | Valid. Acc.: 71.980% | Loss: 0.836
Time elapsed: 18.80 min
Epoch: 018/040 | Batch 0000/0175 | Loss: 0.5772
Epoch: 018/040 | Batch 0050/0175 | Loss: 0.4951
Epoch: 018/040 | Batch 0100/0175 | Loss: 0.4850
Epoch: 018/040 | Batch 0150/0175 | Loss: 0.6942
***Epoch: 018/040 | Train. Acc.: 82.944% | Loss: 0.486
***Epoch: 018/040 | Valid. Acc.: 71.520% | Loss: 0.839
Time elapsed: 19.84 min
Epoch: 019/040 | Batch 0000/0175 | Loss: 0.4757
Epoch: 019/040 | Batch 0050/0175 | Loss: 0.4909
Epoch: 019/040 | Batch 0100/0175 | Loss: 0.5568
Epoch: 019/040 | Batch 0150/0175 | Loss: 0.5895
***Epoch: 019/040 | Train. Acc.: 81.592% | Loss: 0.515
***Epoch: 019/040 | Valid. Acc.: 70.840% | Loss: 0.911
Time elapsed: 20.87 min
Epoch: 020/040 | Batch 0000/0175 | Loss: 0.5108
Epoch: 020/040 | Batch 0050/0175 | Loss: 0.5133
Epoch: 020/040 | Batch 0100/0175 | Loss: 0.4775
Epoch: 020/040 | Batch 0150/0175 | Loss: 0.5364
***Epoch: 020/040 | Train. Acc.: 85.272% | Loss: 0.431
***Epoch: 020/040 | Valid. Acc.: 72.240% | Loss: 0.850
Time elapsed: 21.89 min
Epoch: 021/040 | Batch 0000/0175 | Loss: 0.4184
Epoch: 021/040 | Batch 0050/0175 | Loss: 0.5490
Epoch: 021/040 | Batch 0100/0175 | Loss: 0.4124
Epoch: 021/040 | Batch 0150/0175 | Loss: 0.3877
***Epoch: 021/040 | Train. Acc.: 86.616% | Loss: 0.384
***Epoch: 021/040 | Valid. Acc.: 72.900% | Loss: 0.850
Time elapsed: 22.93 min
Epoch: 022/040 | Batch 0000/0175 | Loss: 0.3587
Epoch: 022/040 | Batch 0050/0175 | Loss: 0.4164
Epoch: 022/040 | Batch 0100/0175 | Loss: 0.4908
Epoch: 022/040 | Batch 0150/0175 | Loss: 0.5300
***Epoch: 022/040 | Train. Acc.: 87.763% | Loss: 0.353
***Epoch: 022/040 | Valid. Acc.: 73.160% | Loss: 0.867
Time elapsed: 23.97 min
Epoch: 023/040 | Batch 0000/0175 | Loss: 0.3409
Epoch: 023/040 | Batch 0050/0175 | Loss: 0.3932
Epoch: 023/040 | Batch 0100/0175 | Loss: 0.4906
Epoch: 023/040 | Batch 0150/0175 | Loss: 0.3842
***Epoch: 023/040 | Train. Acc.: 88.516% | Loss: 0.325
***Epoch: 023/040 | Valid. Acc.: 72.940% | Loss: 0.867
Time elapsed: 25.01 min
Epoch: 024/040 | Batch 0000/0175 | Loss: 0.3903
Epoch: 024/040 | Batch 0050/0175 | Loss: 0.4127
Epoch: 024/040 | Batch 0100/0175 | Loss: 0.3478
Epoch: 024/040 | Batch 0150/0175 | Loss: 0.4306
***Epoch: 024/040 | Train. Acc.: 90.315% | Loss: 0.284
***Epoch: 024/040 | Valid. Acc.: 73.220% | Loss: 0.911
Time elapsed: 26.04 min
Epoch: 025/040 | Batch 0000/0175 | Loss: 0.2716
Epoch: 025/040 | Batch 0050/0175 | Loss: 0.3371
Epoch: 025/040 | Batch 0100/0175 | Loss: 0.4309
Epoch: 025/040 | Batch 0150/0175 | Loss: 0.4343
***Epoch: 025/040 | Train. Acc.: 88.908% | Loss: 0.311
***Epoch: 025/040 | Valid. Acc.: 73.000% | Loss: 0.909
Time elapsed: 27.07 min
Epoch: 026/040 | Batch 0000/0175 | Loss: 0.2467
Epoch: 026/040 | Batch 0050/0175 | Loss: 0.2832
Epoch: 026/040 | Batch 0100/0175 | Loss: 0.3431
Epoch: 026/040 | Batch 0150/0175 | Loss: 0.3218
***Epoch: 026/040 | Train. Acc.: 90.547% | Loss: 0.272
***Epoch: 026/040 | Valid. Acc.: 72.900% | Loss: 0.925
Time elapsed: 28.10 min
Epoch: 027/040 | Batch 0000/0175 | Loss: 0.3064
Epoch: 027/040 | Batch 0050/0175 | Loss: 0.2874
Epoch: 027/040 | Batch 0100/0175 | Loss: 0.3545
Epoch: 027/040 | Batch 0150/0175 | Loss: 0.3866
***Epoch: 027/040 | Train. Acc.: 92.277% | Loss: 0.230
***Epoch: 027/040 | Valid. Acc.: 73.760% | Loss: 0.935
Time elapsed: 29.13 min
Epoch: 028/040 | Batch 0000/0175 | Loss: 0.1964
Epoch: 028/040 | Batch 0050/0175 | Loss: 0.2317
Epoch: 028/040 | Batch 0100/0175 | Loss: 0.2595
Epoch: 028/040 | Batch 0150/0175 | Loss: 0.3056
***Epoch: 028/040 | Train. Acc.: 92.049% | Loss: 0.225
***Epoch: 028/040 | Valid. Acc.: 73.340% | Loss: 0.994
Time elapsed: 30.16 min
Epoch: 029/040 | Batch 0000/0175 | Loss: 0.2118
Epoch: 029/040 | Batch 0050/0175 | Loss: 0.2198
Epoch: 029/040 | Batch 0100/0175 | Loss: 0.2389
Epoch: 029/040 | Batch 0150/0175 | Loss: 0.3052
***Epoch: 029/040 | Train. Acc.: 93.170% | Loss: 0.198
***Epoch: 029/040 | Valid. Acc.: 73.520% | Loss: 1.004
Time elapsed: 31.20 min
Epoch: 030/040 | Batch 0000/0175 | Loss: 0.1664
Epoch: 030/040 | Batch 0050/0175 | Loss: 0.1880
Epoch: 030/040 | Batch 0100/0175 | Loss: 0.1938
Epoch: 030/040 | Batch 0150/0175 | Loss: 0.2032
***Epoch: 030/040 | Train. Acc.: 93.333% | Loss: 0.188
***Epoch: 030/040 | Valid. Acc.: 72.820% | Loss: 1.061
Time elapsed: 32.23 min
Epoch: 031/040 | Batch 0000/0175 | Loss: 0.2679
Epoch: 031/040 | Batch 0050/0175 | Loss: 0.2778
Epoch: 031/040 | Batch 0100/0175 | Loss: 0.2026
Epoch: 031/040 | Batch 0150/0175 | Loss: 0.2144
***Epoch: 031/040 | Train. Acc.: 94.058% | Loss: 0.170
***Epoch: 031/040 | Valid. Acc.: 73.500% | Loss: 1.044
Time elapsed: 33.27 min
Epoch: 032/040 | Batch 0000/0175 | Loss: 0.1634
Epoch: 032/040 | Batch 0050/0175 | Loss: 0.2475
Epoch: 032/040 | Batch 0100/0175 | Loss: 0.1528
Epoch: 032/040 | Batch 0150/0175 | Loss: 0.2810
***Epoch: 032/040 | Train. Acc.: 94.471% | Loss: 0.161
***Epoch: 032/040 | Valid. Acc.: 73.000% | Loss: 1.065
Time elapsed: 34.32 min
Epoch: 033/040 | Batch 0000/0175 | Loss: 0.2095
Epoch: 033/040 | Batch 0050/0175 | Loss: 0.1590
Epoch: 033/040 | Batch 0100/0175 | Loss: 0.1752
Epoch: 033/040 | Batch 0150/0175 | Loss: 0.2319
***Epoch: 033/040 | Train. Acc.: 95.118% | Loss: 0.141
***Epoch: 033/040 | Valid. Acc.: 73.440% | Loss: 1.075
Time elapsed: 35.35 min
Epoch: 034/040 | Batch 0000/0175 | Loss: 0.1156
Epoch: 034/040 | Batch 0050/0175 | Loss: 0.1456
Epoch: 034/040 | Batch 0100/0175 | Loss: 0.1519
Epoch: 034/040 | Batch 0150/0175 | Loss: 0.1831
***Epoch: 034/040 | Train. Acc.: 95.286% | Loss: 0.142
***Epoch: 034/040 | Valid. Acc.: 73.820% | Loss: 1.064
Time elapsed: 36.40 min
Epoch: 035/040 | Batch 0000/0175 | Loss: 0.1532
Epoch: 035/040 | Batch 0050/0175 | Loss: 0.1267
Epoch: 035/040 | Batch 0100/0175 | Loss: 0.1633
Epoch: 035/040 | Batch 0150/0175 | Loss: 0.1263
***Epoch: 035/040 | Train. Acc.: 94.638% | Loss: 0.154
***Epoch: 035/040 | Valid. Acc.: 73.440% | Loss: 1.093
Time elapsed: 37.43 min
Epoch: 036/040 | Batch 0000/0175 | Loss: 0.1787
Epoch: 036/040 | Batch 0050/0175 | Loss: 0.1622
Epoch: 036/040 | Batch 0100/0175 | Loss: 0.1840
Epoch: 036/040 | Batch 0150/0175 | Loss: 0.1143
***Epoch: 036/040 | Train. Acc.: 95.480% | Loss: 0.132
***Epoch: 036/040 | Valid. Acc.: 73.020% | Loss: 1.159
Time elapsed: 38.46 min
Epoch: 037/040 | Batch 0000/0175 | Loss: 0.1282
Epoch: 037/040 | Batch 0050/0175 | Loss: 0.1299
Epoch: 037/040 | Batch 0100/0175 | Loss: 0.1869
Epoch: 037/040 | Batch 0150/0175 | Loss: 0.1387
***Epoch: 037/040 | Train. Acc.: 95.129% | Loss: 0.138
***Epoch: 037/040 | Valid. Acc.: 72.640% | Loss: 1.174
Time elapsed: 39.50 min
Epoch: 038/040 | Batch 0000/0175 | Loss: 0.1137
Epoch: 038/040 | Batch 0050/0175 | Loss: 0.1053
Epoch: 038/040 | Batch 0100/0175 | Loss: 0.1298
Epoch: 038/040 | Batch 0150/0175 | Loss: 0.1280
***Epoch: 038/040 | Train. Acc.: 95.429% | Loss: 0.134
***Epoch: 038/040 | Valid. Acc.: 73.040% | Loss: 1.230
Time elapsed: 40.53 min
Epoch: 039/040 | Batch 0000/0175 | Loss: 0.1410
Epoch: 039/040 | Batch 0050/0175 | Loss: 0.1084
Epoch: 039/040 | Batch 0100/0175 | Loss: 0.1578
Epoch: 039/040 | Batch 0150/0175 | Loss: 0.1516
***Epoch: 039/040 | Train. Acc.: 97.002% | Loss: 0.090
***Epoch: 039/040 | Valid. Acc.: 73.420% | Loss: 1.159
Time elapsed: 41.57 min
Epoch: 040/040 | Batch 0000/0175 | Loss: 0.1143
Epoch: 040/040 | Batch 0050/0175 | Loss: 0.1153
Epoch: 040/040 | Batch 0100/0175 | Loss: 0.1493
Epoch: 040/040 | Batch 0150/0175 | Loss: 0.2771
***Epoch: 040/040 | Train. Acc.: 96.071% | Loss: 0.111
***Epoch: 040/040 | Valid. Acc.: 73.520% | Loss: 1.218
Time elapsed: 42.61 min
Total Training Time: 42.61 min

Evaluation

In [12]:
import matplotlib.pyplot as plt
%matplotlib inline
In [14]:
loss_list = log_dict['train_loss_per_batch']

plt.plot(loss_list, label='Minibatch loss')
plt.plot(np.convolve(loss_list, 
                     np.ones(200,)/200, mode='valid'), 
         label='Running average')

plt.ylabel('Cross Entropy')
plt.xlabel('Iteration')
plt.legend()
plt.show()
In [17]:
plt.plot(np.arange(1, NUM_EPOCHS+1), log_dict['train_acc_per_epoch'], label='Training')
plt.plot(np.arange(1, NUM_EPOCHS+1), log_dict['valid_acc_per_epoch'], label='Validation')

plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
In [18]:
with torch.set_grad_enabled(False):
    
    train_acc = compute_accuracy(model=model,
                                 data_loader=test_loader,
                                 device=DEVICE)
    
    test_acc = compute_accuracy(model=model,
                                data_loader=test_loader,
                                device=DEVICE)
    
    valid_acc = compute_accuracy(model=model,
                                 data_loader=valid_loader,
                                 device=DEVICE)
    

print(f'Train ACC: {valid_acc:.2f}%')
print(f'Validation ACC: {valid_acc:.2f}%')
print(f'Test ACC: {test_acc:.2f}%')
Train ACC: 73.52%
Validation ACC: 73.52%
Test ACC: 72.11%
In [19]:
%watermark -iv
sys        : 3.8.12 | packaged by conda-forge | (default, Oct 12 2021, 21:59:51) 
[GCC 9.4.0]
matplotlib : 3.3.4
PIL        : 9.0.1
torchvision: 0.11.2
numpy      : 1.22.0
torch      : 1.10.1
pandas     : 1.4.1

In [ ]: