import os
import numpy as np
import torch
from torch import nn
import torchvision
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, random_split
# Continue with regular imports
import matplotlib.pyplot as plt
# Try to get torchinfo for summary of model, install it if it doesn't work
try:
    from torchinfo import summary
except ImportError:
    !pip install -q torchinfo
    from torchinfo import summary
from utils.accuracy import accuracy
from utils.train_test import train_loop, test_loop
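# train_loop and test_loop live in a local utils package not shown here. For
# reference only, a minimal sketch consistent with how they are called below
# (mean losses returned as tensors, accuracy as a percentage, BCEWithLogitsLoss
# with a single-logit head) might look like this -- the notebook itself uses the
# imported versions:
def train_loop(dataloader, model, loss_fn, optimizer, device):
    """One training epoch; returns the mean batch loss."""
    model.train()
    total_loss = 0.0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device).float().unsqueeze(1)
        loss = loss_fn(model(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().cpu()
    return total_loss / len(dataloader)

def test_loop(dataloader, model, loss_fn, device):
    """Evaluates on the test set; returns (mean loss, accuracy in %)."""
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device).float().unsqueeze(1)
            logits = model(X)
            total_loss += loss_fn(logits, y).cpu()
            correct += ((logits > 0).float() == y).sum().item()
            seen += y.numel()
    return total_loss / len(dataloader), 100 * correct / seen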
# Setup device agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device
'cuda'
# All models are pretrained on IMAGENET1K
train_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.RandomRotation(degrees=(-15, 15)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
])
test_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
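# Aside (my suggestion, needs torchvision >= 0.13): the weights object bundles
# the canonical inference preprocessing, so the Resize/Normalize values above
# don't have to be hand-copied (note this variant applies no augmentation):
auto_transform = torchvision.models.AlexNet_Weights.IMAGENET1K_V1.transforms()
print(auto_transform)  # Resize -> CenterCrop -> tensor conversion -> Normalize(ImageNet stats)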
DATA_DIR = 'brain_dataset'
SPLIT_RATIO = 0.8
# Import dataset
dataset = datasets.ImageFolder(root=DATA_DIR)
# Split dataset
train_size = int(SPLIT_RATIO * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
# Assign transforms. Caveat: both Subsets wrap the SAME underlying ImageFolder,
# so the second assignment overwrites the first and both splits end up using
# test_transform (see the wrapper sketch below for a proper per-split fix).
train_dataset.dataset.transform = train_transform
test_dataset.dataset.transform = test_transform
classes = dataset.classes
len(train_dataset), len(test_dataset), classes
(201, 51, ['no', 'yes'])
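# A minimal sketch (my addition, not from the original notebook) of one way to
# give each split its own transform. For this to work, the ImageFolder above
# must be created with transform=None so the Subsets yield PIL images.
from torch.utils.data import Dataset

class TransformedSubset(Dataset):
    """Applies a per-split transform on top of a Subset."""
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, label = self.subset[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# train_dataset = TransformedSubset(train_dataset, train_transform)
# test_dataset = TransformedSubset(test_dataset, test_transform)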
BATCH_SIZE = 32
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=os.cpu_count())
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=os.cpu_count())
train_dataloader, test_dataloader
(<torch.utils.data.dataloader.DataLoader at 0x29416692450>, <torch.utils.data.dataloader.DataLoader at 0x29415f7ac10>)
fig = plt.figure(figsize=(9, 9))
rows, cols = 4, 4
for i in range(1, rows*cols+1):
    random_idx = torch.randint(0, len(train_dataset), size=[1]).item()
    image, true_label = train_dataset[random_idx]
    fig.add_subplot(rows, cols, i)
    plt.imshow(image.permute(1, 2, 0), cmap="gray")
    plt.title(f"{classes[true_label]}", fontsize=8)
    plt.axis(False)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). [repeated 16x, once per subplot]
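# The clipping warnings above appear because imshow receives ImageNet-normalized
# pixel values outside [0, 1]. A small helper (my addition, not in the original
# notebook) undoes the normalization before plotting:
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def denormalize(img: torch.Tensor) -> torch.Tensor:
    """Undo ImageNet normalization so pixel values land back in [0, 1]."""
    return (img * IMAGENET_STD + IMAGENET_MEAN).clamp(0, 1)

# Usage: plt.imshow(denormalize(image).permute(1, 2, 0))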
weights = torchvision.models.AlexNet_Weights.IMAGENET1K_V1
model = torchvision.models.alexnet(weights=weights)
model.to(device)
model.classifier
Sequential( (0): Dropout(p=0.5, inplace=False) (1): Linear(in_features=9216, out_features=4096, bias=True) (2): ReLU(inplace=True) (3): Dropout(p=0.5, inplace=False) (4): Linear(in_features=4096, out_features=4096, bias=True) (5): ReLU(inplace=True) (6): Linear(in_features=4096, out_features=1000, bias=True) )
summary(model=model,
        input_size=(1, 3, 224, 224),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])
========================================================================================================================
Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
========================================================================================================================
AlexNet (AlexNet)                        [1, 3, 224, 224]     [1, 1000]            --                   True
├─Sequential (features)                  [1, 3, 224, 224]     [1, 256, 6, 6]       --                   True
│    └─Conv2d (0)                        [1, 3, 224, 224]     [1, 64, 55, 55]      23,296               True
│    └─ReLU (1)                          [1, 64, 55, 55]      [1, 64, 55, 55]      --                   --
│    └─MaxPool2d (2)                     [1, 64, 55, 55]      [1, 64, 27, 27]      --                   --
│    └─Conv2d (3)                        [1, 64, 27, 27]      [1, 192, 27, 27]     307,392              True
│    └─ReLU (4)                          [1, 192, 27, 27]     [1, 192, 27, 27]     --                   --
│    └─MaxPool2d (5)                     [1, 192, 27, 27]     [1, 192, 13, 13]     --                   --
│    └─Conv2d (6)                        [1, 192, 13, 13]     [1, 384, 13, 13]     663,936              True
│    └─ReLU (7)                          [1, 384, 13, 13]     [1, 384, 13, 13]     --                   --
│    └─Conv2d (8)                        [1, 384, 13, 13]     [1, 256, 13, 13]     884,992              True
│    └─ReLU (9)                          [1, 256, 13, 13]     [1, 256, 13, 13]     --                   --
│    └─Conv2d (10)                       [1, 256, 13, 13]     [1, 256, 13, 13]     590,080              True
│    └─ReLU (11)                         [1, 256, 13, 13]     [1, 256, 13, 13]     --                   --
│    └─MaxPool2d (12)                    [1, 256, 13, 13]     [1, 256, 6, 6]       --                   --
├─AdaptiveAvgPool2d (avgpool)            [1, 256, 6, 6]       [1, 256, 6, 6]       --                   --
├─Sequential (classifier)                [1, 9216]            [1, 1000]            --                   True
│    └─Dropout (0)                       [1, 9216]            [1, 9216]            --                   --
│    └─Linear (1)                        [1, 9216]            [1, 4096]            37,752,832           True
│    └─ReLU (2)                          [1, 4096]            [1, 4096]            --                   --
│    └─Dropout (3)                       [1, 4096]            [1, 4096]            --                   --
│    └─Linear (4)                        [1, 4096]            [1, 4096]            16,781,312           True
│    └─ReLU (5)                          [1, 4096]            [1, 4096]            --                   --
│    └─Linear (6)                        [1, 4096]            [1, 1000]            4,097,000            True
========================================================================================================================
Total params: 61,100,840
Trainable params: 61,100,840
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 714.68
========================================================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 3.95
Params size (MB): 244.40
Estimated Total Size (MB): 248.96
========================================================================================================================
# Freeze the convolutional feature extractor
for param in model.features.parameters():
    param.requires_grad = False
# Freeze the first two Linear layers of the classifier (params 0-3: their weights and biases)
for i, param in enumerate(model.classifier.parameters()):
    if i < 4:
        param.requires_grad = False
summary(model=model,
        input_size=(1, 3, 224, 224),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])
========================================================================================================================
Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
========================================================================================================================
AlexNet (AlexNet)                        [1, 3, 224, 224]     [1, 1000]            --                   Partial
├─Sequential (features)                  [1, 3, 224, 224]     [1, 256, 6, 6]       --                   False
│    └─Conv2d (0)                        [1, 3, 224, 224]     [1, 64, 55, 55]      (23,296)             False
│    └─ReLU (1)                          [1, 64, 55, 55]      [1, 64, 55, 55]      --                   --
│    └─MaxPool2d (2)                     [1, 64, 55, 55]      [1, 64, 27, 27]      --                   --
│    └─Conv2d (3)                        [1, 64, 27, 27]      [1, 192, 27, 27]     (307,392)            False
│    └─ReLU (4)                          [1, 192, 27, 27]     [1, 192, 27, 27]     --                   --
│    └─MaxPool2d (5)                     [1, 192, 27, 27]     [1, 192, 13, 13]     --                   --
│    └─Conv2d (6)                        [1, 192, 13, 13]     [1, 384, 13, 13]     (663,936)            False
│    └─ReLU (7)                          [1, 384, 13, 13]     [1, 384, 13, 13]     --                   --
│    └─Conv2d (8)                        [1, 384, 13, 13]     [1, 256, 13, 13]     (884,992)            False
│    └─ReLU (9)                          [1, 256, 13, 13]     [1, 256, 13, 13]     --                   --
│    └─Conv2d (10)                       [1, 256, 13, 13]     [1, 256, 13, 13]     (590,080)            False
│    └─ReLU (11)                         [1, 256, 13, 13]     [1, 256, 13, 13]     --                   --
│    └─MaxPool2d (12)                    [1, 256, 13, 13]     [1, 256, 6, 6]       --                   --
├─AdaptiveAvgPool2d (avgpool)            [1, 256, 6, 6]       [1, 256, 6, 6]       --                   --
├─Sequential (classifier)                [1, 9216]            [1, 1000]            --                   Partial
│    └─Dropout (0)                       [1, 9216]            [1, 9216]            --                   --
│    └─Linear (1)                        [1, 9216]            [1, 4096]            (37,752,832)         False
│    └─ReLU (2)                          [1, 4096]            [1, 4096]            --                   --
│    └─Dropout (3)                       [1, 4096]            [1, 4096]            --                   --
│    └─Linear (4)                        [1, 4096]            [1, 4096]            (16,781,312)         False
│    └─ReLU (5)                          [1, 4096]            [1, 4096]            --                   --
│    └─Linear (6)                        [1, 4096]            [1, 1000]            4,097,000            True
========================================================================================================================
Total params: 61,100,840
Trainable params: 4,097,000
Non-trainable params: 57,003,840
Total mult-adds (Units.MEGABYTES): 714.68
========================================================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 3.95
Params size (MB): 244.40
Estimated Total Size (MB): 248.96
========================================================================================================================
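# Alternative (my variant, same effect for torchvision's AlexNet layout):
# freezing by parameter name is less brittle than counting positional indices.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("classifier.6")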
# Replace the 1000-class ImageNet head with a single-logit binary head,
# moved to the same device as the rest of the model
model.classifier[6] = nn.Linear(4096, 1).to(device)
summary(model=model,
        input_size=(1, 3, 224, 224),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])
========================================================================================================================
Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
========================================================================================================================
AlexNet (AlexNet)                        [1, 3, 224, 224]     [1, 1]               --                   Partial
├─Sequential (features)                  [1, 3, 224, 224]     [1, 256, 6, 6]       --                   False
│    └─Conv2d (0)                        [1, 3, 224, 224]     [1, 64, 55, 55]      (23,296)             False
│    └─ReLU (1)                          [1, 64, 55, 55]      [1, 64, 55, 55]      --                   --
│    └─MaxPool2d (2)                     [1, 64, 55, 55]      [1, 64, 27, 27]      --                   --
│    └─Conv2d (3)                        [1, 64, 27, 27]      [1, 192, 27, 27]     (307,392)            False
│    └─ReLU (4)                          [1, 192, 27, 27]     [1, 192, 27, 27]     --                   --
│    └─MaxPool2d (5)                     [1, 192, 27, 27]     [1, 192, 13, 13]     --                   --
│    └─Conv2d (6)                        [1, 192, 13, 13]     [1, 384, 13, 13]     (663,936)            False
│    └─ReLU (7)                          [1, 384, 13, 13]     [1, 384, 13, 13]     --                   --
│    └─Conv2d (8)                        [1, 384, 13, 13]     [1, 256, 13, 13]     (884,992)            False
│    └─ReLU (9)                          [1, 256, 13, 13]     [1, 256, 13, 13]     --                   --
│    └─Conv2d (10)                       [1, 256, 13, 13]     [1, 256, 13, 13]     (590,080)            False
│    └─ReLU (11)                         [1, 256, 13, 13]     [1, 256, 13, 13]     --                   --
│    └─MaxPool2d (12)                    [1, 256, 13, 13]     [1, 256, 6, 6]       --                   --
├─AdaptiveAvgPool2d (avgpool)            [1, 256, 6, 6]       [1, 256, 6, 6]       --                   --
├─Sequential (classifier)                [1, 9216]            [1, 1]               --                   Partial
│    └─Dropout (0)                       [1, 9216]            [1, 9216]            --                   --
│    └─Linear (1)                        [1, 9216]            [1, 4096]            (37,752,832)         False
│    └─ReLU (2)                          [1, 4096]            [1, 4096]            --                   --
│    └─Dropout (3)                       [1, 4096]            [1, 4096]            --                   --
│    └─Linear (4)                        [1, 4096]            [1, 4096]            (16,781,312)         False
│    └─ReLU (5)                          [1, 4096]            [1, 4096]            --                   --
│    └─Linear (6)                        [1, 4096]            [1, 1]               4,097                True
========================================================================================================================
Total params: 57,007,937
Trainable params: 4,097
Non-trainable params: 57,003,840
Total mult-adds (Units.MEGABYTES): 710.59
========================================================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 3.95
Params size (MB): 228.03
Estimated Total Size (MB): 232.58
========================================================================================================================
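# With a single-logit head, BCEWithLogitsLoss (defined next) applies the sigmoid
# internally during training; at inference, probability and hard label follow
# directly from the raw logit. A toy illustration with dummy logits (my addition):
example_logits = torch.tensor([[-1.2], [0.7]])  # pretend outputs for two images
example_probs = torch.sigmoid(example_logits)   # P(class == 'yes')
example_preds = (example_logits > 0).int()      # same decision as probs > 0.5
print(example_probs.squeeze().tolist(), example_preds.squeeze().tolist())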
loss_fn = nn.BCEWithLogitsLoss()
#optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
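# Passing all parameters to Adam works because frozen ones never receive
# gradients; filtering to the trainable ones (a minor variant, my suggestion)
# makes the intent explicit:
#optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=0.001)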
from tqdm.auto import tqdm
import time
train_losses = []
test_losses = []
test_accuracies = []
EPOCHS = 30
program_starts = time.time()
for epoch in tqdm(range(EPOCHS)):
    train_loss = train_loop(train_dataloader, model, loss_fn, optimizer, device)
    test_loss, test_accuracy = test_loop(test_dataloader, model, loss_fn, device)
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
    print(f"Epoch: {epoch+1} | Training Loss: {train_loss} | Test Loss: {test_loss} | Test acc: {test_accuracy}%")
program_ends = time.time()
print(f"Elapsed: {program_ends-program_starts}s")
Epoch: 1 | Training Loss: 0.6329338550567627 | Test Loss: 0.5103597044944763 | Test acc: 76.47058823529412%
Epoch: 2 | Training Loss: 0.4774479866027832 | Test Loss: 0.4821828305721283 | Test acc: 72.54901960784314%
Epoch: 3 | Training Loss: 0.4524262845516205 | Test Loss: 0.45745745301246643 | Test acc: 78.43137254901961%
Epoch: 4 | Training Loss: 0.38824909925460815 | Test Loss: 0.4513978958129883 | Test acc: 80.3921568627451%
Epoch: 5 | Training Loss: 0.3458859324455261 | Test Loss: 0.4499729871749878 | Test acc: 82.3529411764706%
Epoch: 6 | Training Loss: 0.3082960546016693 | Test Loss: 0.42737245559692383 | Test acc: 84.31372549019608%
Epoch: 7 | Training Loss: 0.3031310439109802 | Test Loss: 0.4216597378253937 | Test acc: 84.31372549019608%
Epoch: 8 | Training Loss: 0.2722262740135193 | Test Loss: 0.4108864367008209 | Test acc: 86.27450980392157%
Epoch: 9 | Training Loss: 0.26732656359672546 | Test Loss: 0.4054611921310425 | Test acc: 86.27450980392157%
Epoch: 10 | Training Loss: 0.2754639685153961 | Test Loss: 0.4053954482078552 | Test acc: 86.27450980392157%
Epoch: 11 | Training Loss: 0.2346958965063095 | Test Loss: 0.40852028131484985 | Test acc: 86.27450980392157%
Epoch: 12 | Training Loss: 0.24143975973129272 | Test Loss: 0.3937559723854065 | Test acc: 86.27450980392157%
Epoch: 13 | Training Loss: 0.2435234934091568 | Test Loss: 0.38772913813591003 | Test acc: 86.27450980392157%
Epoch: 14 | Training Loss: 0.19618511199951172 | Test Loss: 0.3964950144290924 | Test acc: 86.27450980392157%
Epoch: 15 | Training Loss: 0.23904776573181152 | Test Loss: 0.4051493704319 | Test acc: 84.31372549019608%
Epoch: 16 | Training Loss: 0.1947181671857834 | Test Loss: 0.3928598165512085 | Test acc: 86.27450980392157%
Epoch: 17 | Training Loss: 0.2019444853067398 | Test Loss: 0.38535916805267334 | Test acc: 86.27450980392157%
Epoch: 18 | Training Loss: 0.21457834541797638 | Test Loss: 0.40606990456581116 | Test acc: 86.27450980392157%
Epoch: 19 | Training Loss: 0.21077598631381989 | Test Loss: 0.39269891381263733 | Test acc: 86.27450980392157%
Epoch: 20 | Training Loss: 0.20026259124279022 | Test Loss: 0.387015700340271 | Test acc: 86.27450980392157%
Epoch: 21 | Training Loss: 0.19900424778461456 | Test Loss: 0.3949466049671173 | Test acc: 86.27450980392157%
Epoch: 22 | Training Loss: 0.20062680542469025 | Test Loss: 0.385182648897171 | Test acc: 86.27450980392157%
Epoch: 23 | Training Loss: 0.19526785612106323 | Test Loss: 0.38639727234840393 | Test acc: 86.27450980392157%
Epoch: 24 | Training Loss: 0.20533517003059387 | Test Loss: 0.3713395893573761 | Test acc: 86.27450980392157%
Epoch: 25 | Training Loss: 0.1718456745147705 | Test Loss: 0.3612331748008728 | Test acc: 88.23529411764706%
Epoch: 26 | Training Loss: 0.155961811542511 | Test Loss: 0.3814121186733246 | Test acc: 86.27450980392157%
Epoch: 27 | Training Loss: 0.17319394648075104 | Test Loss: 0.3787640631198883 | Test acc: 86.27450980392157%
Epoch: 28 | Training Loss: 0.15943369269371033 | Test Loss: 0.37068402767181396 | Test acc: 86.27450980392157%
Epoch: 29 | Training Loss: 0.17207781970500946 | Test Loss: 0.35953351855278015 | Test acc: 88.23529411764706%
Epoch: 30 | Training Loss: 0.1804381161928177 | Test Loss: 0.34932616353034973 | Test acc: 88.23529411764706%
Elapsed: 638.1728940010071s
# Loss curves
plt.plot(test_losses, label="test")
plt.plot(train_losses, label="train")
plt.legend()
plt.show()
# Test accuracy curve (train accuracy was not tracked during training)
plt.plot(test_accuracies, label="test")
plt.legend()
plt.show()
# Put the model in eval mode (disables dropout) and skip gradient tracking
model.eval()
fig = plt.figure(figsize=(9, 9))
rows, cols = 4, 4
with torch.inference_mode():
    for i in range(1, rows*cols+1):
        random_idx = torch.randint(0, len(test_dataset), size=[1]).item()
        image, true_label = test_dataset[random_idx]
        predicted_label = model(image.unsqueeze(dim=0).to(device))
        predicted_label = (predicted_label > 0).int().squeeze()  # logit > 0 <=> p > 0.5
        fig.add_subplot(rows, cols, i)
        plt.imshow(image.permute(1, 2, 0), cmap="gray")
        color = "g" if true_label == predicted_label else "r"
        plt.title(f"True: {classes[true_label]} | Pred: {classes[predicted_label]}", fontsize=8, c=color)
        plt.axis(False)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). [repeated 16x, once per subplot]
try:
    from torchmetrics import ConfusionMatrix
except ImportError:
    !pip install torchmetrics
    from torchmetrics import ConfusionMatrix
try:
    from mlxtend.plotting import plot_confusion_matrix
except ImportError:
    !pip install mlxtend
    from mlxtend.plotting import plot_confusion_matrix
y_preds = []
true_labels = []
model.eval()
with torch.inference_mode():
    for i in tqdm(range(len(test_dataset))):
        image, true_label = test_dataset[i]
        predicted_label = model(image.unsqueeze(dim=0).to(device))
        predicted_label = (predicted_label > 0).int().squeeze()
        y_preds.append(predicted_label.item())
        true_labels.append(true_label)
y_preds = torch.tensor(y_preds)
true_labels = torch.tensor(true_labels)
confmat = ConfusionMatrix(task="multiclass", num_classes=2)
matrix = np.array(confmat(y_preds, true_labels))
fig, ax = plot_confusion_matrix(conf_mat=matrix, colorbar=True, show_absolute=False, show_normed=True, class_names=classes)
plt.show()
print([loss.item() for loss in train_losses])
print([loss.item() for loss in test_losses])
print(test_accuracies)
print([[int(i) for i in row]for row in matrix])
[0.6329338550567627, 0.4774479866027832, 0.4524262845516205, 0.38824909925460815, 0.3458859324455261, 0.3082960546016693, 0.3031310439109802, 0.2722262740135193, 0.26732656359672546, 0.2754639685153961, 0.2346958965063095, 0.24143975973129272, 0.2435234934091568, 0.19618511199951172, 0.23904776573181152, 0.1947181671857834, 0.2019444853067398, 0.21457834541797638, 0.21077598631381989, 0.20026259124279022, 0.19900424778461456, 0.20062680542469025, 0.19526785612106323, 0.20533517003059387, 0.1718456745147705, 0.155961811542511, 0.17319394648075104, 0.15943369269371033, 0.17207781970500946, 0.1804381161928177]
[0.5103597044944763, 0.4821828305721283, 0.45745745301246643, 0.4513978958129883, 0.4499729871749878, 0.42737245559692383, 0.4216597378253937, 0.4108864367008209, 0.4054611921310425, 0.4053954482078552, 0.40852028131484985, 0.3937559723854065, 0.38772913813591003, 0.3964950144290924, 0.4051493704319, 0.3928598165512085, 0.38535916805267334, 0.40606990456581116, 0.39269891381263733, 0.387015700340271, 0.3949466049671173, 0.385182648897171, 0.38639727234840393, 0.3713395893573761, 0.3612331748008728, 0.3814121186733246, 0.3787640631198883, 0.37068402767181396, 0.35953351855278015, 0.34932616353034973]
[76.47058823529412, 72.54901960784314, 78.43137254901961, 80.3921568627451, 82.3529411764706, 84.31372549019608, 84.31372549019608, 86.27450980392157, 86.27450980392157, 86.27450980392157, 86.27450980392157, 86.27450980392157, 86.27450980392157, 86.27450980392157, 84.31372549019608, 86.27450980392157, 86.27450980392157, 86.27450980392157, 86.27450980392157, 86.27450980392157, 86.27450980392157, 86.27450980392157, 86.27450980392157, 86.27450980392157, 88.23529411764706, 86.27450980392157, 86.27450980392157, 86.27450980392157, 88.23529411764706, 88.23529411764706]
[[15, 4], [2, 30]]
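# For a medical screening task, per-class rates say more than raw accuracy.
# Reading them off the matrix (my addition; torchmetrics puts true labels on
# rows and predictions on columns, with class 0 = 'no' as the negative class):
(tn, fp), (fn, tp) = matrix
sensitivity = tp / (tp + fn)  # recall on tumour ('yes') cases: 30/32 ≈ 0.94
specificity = tn / (tn + fp)  # recall on healthy ('no') cases: 15/19 ≈ 0.79
print(f"Sensitivity: {sensitivity:.3f} | Specificity: {specificity:.3f}")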