import os
import numpy as np
import torch
from torch import nn
import torchvision
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, random_split
# Continue with regular imports
import matplotlib.pyplot as plt
# Try to get torchinfo for summary of model, install it if it doesn't work
try:
from torchinfo import summary
except:
!pip install -q torchinfo
from torchinfo import summary
from utils.accuracy import accuracy
from utils.train_test import *
# Choose the compute device: the GPU when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device
'cuda'
# Preprocessing pipelines. The backbone is pretrained on IMAGENET1K, so every
# image is resized to 224x224 and normalized per channel with the mean/std
# given below.
_input_size = (224, 224)
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: adds light augmentation (small random rotation and a
# random horizontal flip) before tensor conversion and normalization.
_train_steps = [
    transforms.Resize(size=_input_size),
    transforms.RandomRotation(degrees=(-15, 15)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    _normalize,
]
train_transform = transforms.Compose(_train_steps)

# Evaluation pipeline: deterministic, no augmentation.
_test_steps = [
    transforms.Resize(size=_input_size),
    transforms.ToTensor(),
    _normalize,
]
test_transform = transforms.Compose(_test_steps)
DATA_DIR = 'brain_dataset'
SPLIT_RATIO = 0.8
SPLIT_SEED = 42  # fixed seed so the train/test partition is reproducible

# Transform-free view of the image folder: used only to size the split,
# draw the indices and read the class names.
dataset = datasets.ImageFolder(root=DATA_DIR)
classes = dataset.classes

# Split sizes: SPLIT_RATIO of the samples for training, the remainder for testing.
train_size = int(SPLIT_RATIO * len(dataset))
test_size = len(dataset) - train_size
train_split, test_split = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# The two subsets returned by random_split wrap the SAME underlying dataset
# object. Assigning `subset.dataset.transform` once per subset therefore
# leaves both with whichever transform was assigned last, and training would
# run without augmentation. Instead, give each split its own ImageFolder
# carrying the right transform and reuse the drawn indices. ImageFolder lists
# its samples in sorted order, so an index refers to the same file in every
# instance built from the same root.
train_dataset = torch.utils.data.Subset(
    datasets.ImageFolder(root=DATA_DIR, transform=train_transform),
    train_split.indices,
)
test_dataset = torch.utils.data.Subset(
    datasets.ImageFolder(root=DATA_DIR, transform=test_transform),
    test_split.indices,
)

len(train_dataset), len(test_dataset), classes
(201, 51, ['no', 'yes'])
BATCH_SIZE = 32
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader does not accept None for num_workers. Fall back to 0 (load in
# the main process) in that case.
NUM_WORKERS = os.cpu_count() or 0

# Training batches are reshuffled every epoch; test batches keep a fixed order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)
train_dataloader, test_dataloader
(<torch.utils.data.dataloader.DataLoader at 0x29ac42bce50>, <torch.utils.data.dataloader.DataLoader at 0x29ac42a2710>)
# Show a 4x4 grid of randomly chosen training samples, titled with their class.
#
# The dataset yields normalized tensors (see the Normalize step of the
# transforms), whose values fall outside [0, 1]. Passing them straight to
# imshow makes matplotlib clip them and emit a warning per image. Undo the
# normalization and clamp to [0, 1] before plotting. The `cmap` argument is
# dropped because matplotlib ignores it for RGB input.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

fig = plt.figure(figsize=(9, 9))
rows, cols = 4, 4
for i in range(1, rows * cols + 1):
    random_idx = torch.randint(0, len(train_dataset), size=[1]).item()
    image, true_label = train_dataset[random_idx]
    # (C, H, W) normalized tensor -> de-normalized (H, W, C) image for imshow.
    display_image = (image * norm_std + norm_mean).clamp(0, 1).permute(1, 2, 0)
    fig.add_subplot(rows, cols, i)
    plt.imshow(display_image)
    plt.title(f"{classes[true_label]}", fontsize=8)
    plt.axis(False)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Build MobileNetV3-Large with the IMAGENET1K_V2 pretrained weights and place
# it on the selected device (`Module.to` returns the module itself).
weights = torchvision.models.MobileNet_V3_Large_Weights.IMAGENET1K_V2
model = torchvision.models.mobilenet_v3_large(weights=weights).to(device)

# Display the stock classifier head before it is modified.
model.classifier
Sequential( (0): Linear(in_features=960, out_features=1280, bias=True) (1): Hardswish() (2): Dropout(p=0.2, inplace=True) (3): Linear(in_features=1280, out_features=1000, bias=True) )
# Per-layer table of the untouched pretrained model for one 3x224x224 image:
# input/output shapes, parameter counts and trainable flags.
summary(
    model=model,
    input_size=(1, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)
============================================================================================================================================ Layer (type (var_name)) Input Shape Output Shape Param # Trainable ============================================================================================================================================ MobileNetV3 (MobileNetV3) [1, 3, 224, 224] [1, 1000] -- True ├─Sequential (features) [1, 3, 224, 224] [1, 960, 7, 7] -- True │ └─Conv2dNormActivation (0) [1, 3, 224, 224] [1, 16, 112, 112] -- True │ │ └─Conv2d (0) [1, 3, 224, 224] [1, 16, 112, 112] 432 True │ │ └─BatchNorm2d (1) [1, 16, 112, 112] [1, 16, 112, 112] 32 True │ │ └─Hardswish (2) [1, 16, 112, 112] [1, 16, 112, 112] -- -- │ └─InvertedResidual (1) [1, 16, 112, 112] [1, 16, 112, 112] -- True │ │ └─Sequential (block) [1, 16, 112, 112] [1, 16, 112, 112] 464 True │ └─InvertedResidual (2) [1, 16, 112, 112] [1, 24, 56, 56] -- True │ │ └─Sequential (block) [1, 16, 112, 112] [1, 24, 56, 56] 3,440 True │ └─InvertedResidual (3) [1, 24, 56, 56] [1, 24, 56, 56] -- True │ │ └─Sequential (block) [1, 24, 56, 56] [1, 24, 56, 56] 4,440 True │ └─InvertedResidual (4) [1, 24, 56, 56] [1, 40, 28, 28] -- True │ │ └─Sequential (block) [1, 24, 56, 56] [1, 40, 28, 28] 10,328 True │ └─InvertedResidual (5) [1, 40, 28, 28] [1, 40, 28, 28] -- True │ │ └─Sequential (block) [1, 40, 28, 28] [1, 40, 28, 28] 20,992 True │ └─InvertedResidual (6) [1, 40, 28, 28] [1, 40, 28, 28] -- True │ │ └─Sequential (block) [1, 40, 28, 28] [1, 40, 28, 28] 20,992 True │ └─InvertedResidual (7) [1, 40, 28, 28] [1, 80, 14, 14] -- True │ │ └─Sequential (block) [1, 40, 28, 28] [1, 80, 14, 14] 32,080 True │ └─InvertedResidual (8) [1, 80, 14, 14] [1, 80, 14, 14] -- True │ │ └─Sequential (block) [1, 80, 14, 14] [1, 80, 14, 14] 34,760 True │ └─InvertedResidual (9) [1, 80, 14, 14] [1, 80, 14, 14] -- True │ │ └─Sequential (block) [1, 80, 14, 14] [1, 80, 14, 14] 31,992 True │ └─InvertedResidual (10) [1, 80, 14, 14] [1, 80, 
14, 14] -- True │ │ └─Sequential (block) [1, 80, 14, 14] [1, 80, 14, 14] 31,992 True │ └─InvertedResidual (11) [1, 80, 14, 14] [1, 112, 14, 14] -- True │ │ └─Sequential (block) [1, 80, 14, 14] [1, 112, 14, 14] 214,424 True │ └─InvertedResidual (12) [1, 112, 14, 14] [1, 112, 14, 14] -- True │ │ └─Sequential (block) [1, 112, 14, 14] [1, 112, 14, 14] 386,120 True │ └─InvertedResidual (13) [1, 112, 14, 14] [1, 160, 7, 7] -- True │ │ └─Sequential (block) [1, 112, 14, 14] [1, 160, 7, 7] 429,224 True │ └─InvertedResidual (14) [1, 160, 7, 7] [1, 160, 7, 7] -- True │ │ └─Sequential (block) [1, 160, 7, 7] [1, 160, 7, 7] 797,360 True │ └─InvertedResidual (15) [1, 160, 7, 7] [1, 160, 7, 7] -- True │ │ └─Sequential (block) [1, 160, 7, 7] [1, 160, 7, 7] 797,360 True │ └─Conv2dNormActivation (16) [1, 160, 7, 7] [1, 960, 7, 7] -- True │ │ └─Conv2d (0) [1, 160, 7, 7] [1, 960, 7, 7] 153,600 True │ │ └─BatchNorm2d (1) [1, 960, 7, 7] [1, 960, 7, 7] 1,920 True │ │ └─Hardswish (2) [1, 960, 7, 7] [1, 960, 7, 7] -- -- ├─AdaptiveAvgPool2d (avgpool) [1, 960, 7, 7] [1, 960, 1, 1] -- -- ├─Sequential (classifier) [1, 960] [1, 1000] -- True │ └─Linear (0) [1, 960] [1, 1280] 1,230,080 True │ └─Hardswish (1) [1, 1280] [1, 1280] -- -- │ └─Dropout (2) [1, 1280] [1, 1280] -- -- │ └─Linear (3) [1, 1280] [1, 1000] 1,281,000 True ============================================================================================================================================ Total params: 5,483,032 Trainable params: 5,483,032 Non-trainable params: 0 Total mult-adds (Units.MEGABYTES): 216.62 ============================================================================================================================================ Input size (MB): 0.60 Forward/backward pass size (MB): 70.46 Params size (MB): 21.93 Estimated Total Size (MB): 92.99 ============================================================================================================================================
# Freeze the whole feature extractor.
model.features.requires_grad_(False)

# Freeze the first linear layer of the classifier as well. Its weight and
# bias are the first two parameters of `model.classifier`; the final linear
# layer stays trainable.
model.classifier[0].requires_grad_(False)

# Re-print the summary to confirm which layers are still trainable.
summary(
    model=model,
    input_size=(1, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)
============================================================================================================================================ Layer (type (var_name)) Input Shape Output Shape Param # Trainable ============================================================================================================================================ MobileNetV3 (MobileNetV3) [1, 3, 224, 224] [1, 1000] -- Partial ├─Sequential (features) [1, 3, 224, 224] [1, 960, 7, 7] -- False │ └─Conv2dNormActivation (0) [1, 3, 224, 224] [1, 16, 112, 112] -- False │ │ └─Conv2d (0) [1, 3, 224, 224] [1, 16, 112, 112] (432) False │ │ └─BatchNorm2d (1) [1, 16, 112, 112] [1, 16, 112, 112] (32) False │ │ └─Hardswish (2) [1, 16, 112, 112] [1, 16, 112, 112] -- -- │ └─InvertedResidual (1) [1, 16, 112, 112] [1, 16, 112, 112] -- False │ │ └─Sequential (block) [1, 16, 112, 112] [1, 16, 112, 112] (464) False │ └─InvertedResidual (2) [1, 16, 112, 112] [1, 24, 56, 56] -- False │ │ └─Sequential (block) [1, 16, 112, 112] [1, 24, 56, 56] (3,440) False │ └─InvertedResidual (3) [1, 24, 56, 56] [1, 24, 56, 56] -- False │ │ └─Sequential (block) [1, 24, 56, 56] [1, 24, 56, 56] (4,440) False │ └─InvertedResidual (4) [1, 24, 56, 56] [1, 40, 28, 28] -- False │ │ └─Sequential (block) [1, 24, 56, 56] [1, 40, 28, 28] (10,328) False │ └─InvertedResidual (5) [1, 40, 28, 28] [1, 40, 28, 28] -- False │ │ └─Sequential (block) [1, 40, 28, 28] [1, 40, 28, 28] (20,992) False │ └─InvertedResidual (6) [1, 40, 28, 28] [1, 40, 28, 28] -- False │ │ └─Sequential (block) [1, 40, 28, 28] [1, 40, 28, 28] (20,992) False │ └─InvertedResidual (7) [1, 40, 28, 28] [1, 80, 14, 14] -- False │ │ └─Sequential (block) [1, 40, 28, 28] [1, 80, 14, 14] (32,080) False │ └─InvertedResidual (8) [1, 80, 14, 14] [1, 80, 14, 14] -- False │ │ └─Sequential (block) [1, 80, 14, 14] [1, 80, 14, 14] (34,760) False │ └─InvertedResidual (9) [1, 80, 14, 14] [1, 80, 14, 14] -- False │ │ └─Sequential (block) [1, 80, 14, 14] [1, 80, 14, 14] (31,992) False │ 
└─InvertedResidual (10) [1, 80, 14, 14] [1, 80, 14, 14] -- False │ │ └─Sequential (block) [1, 80, 14, 14] [1, 80, 14, 14] (31,992) False │ └─InvertedResidual (11) [1, 80, 14, 14] [1, 112, 14, 14] -- False │ │ └─Sequential (block) [1, 80, 14, 14] [1, 112, 14, 14] (214,424) False │ └─InvertedResidual (12) [1, 112, 14, 14] [1, 112, 14, 14] -- False │ │ └─Sequential (block) [1, 112, 14, 14] [1, 112, 14, 14] (386,120) False │ └─InvertedResidual (13) [1, 112, 14, 14] [1, 160, 7, 7] -- False │ │ └─Sequential (block) [1, 112, 14, 14] [1, 160, 7, 7] (429,224) False │ └─InvertedResidual (14) [1, 160, 7, 7] [1, 160, 7, 7] -- False │ │ └─Sequential (block) [1, 160, 7, 7] [1, 160, 7, 7] (797,360) False │ └─InvertedResidual (15) [1, 160, 7, 7] [1, 160, 7, 7] -- False │ │ └─Sequential (block) [1, 160, 7, 7] [1, 160, 7, 7] (797,360) False │ └─Conv2dNormActivation (16) [1, 160, 7, 7] [1, 960, 7, 7] -- False │ │ └─Conv2d (0) [1, 160, 7, 7] [1, 960, 7, 7] (153,600) False │ │ └─BatchNorm2d (1) [1, 960, 7, 7] [1, 960, 7, 7] (1,920) False │ │ └─Hardswish (2) [1, 960, 7, 7] [1, 960, 7, 7] -- -- ├─AdaptiveAvgPool2d (avgpool) [1, 960, 7, 7] [1, 960, 1, 1] -- -- ├─Sequential (classifier) [1, 960] [1, 1000] -- Partial │ └─Linear (0) [1, 960] [1, 1280] (1,230,080) False │ └─Hardswish (1) [1, 1280] [1, 1280] -- -- │ └─Dropout (2) [1, 1280] [1, 1280] -- -- │ └─Linear (3) [1, 1280] [1, 1000] 1,281,000 True ============================================================================================================================================ Total params: 5,483,032 Trainable params: 1,281,000 Non-trainable params: 4,202,032 Total mult-adds (Units.MEGABYTES): 216.62 ============================================================================================================================================ Input size (MB): 0.60 Forward/backward pass size (MB): 70.46 Params size (MB): 21.93 Estimated Total Size (MB): 92.99 
============================================================================================================================================
# Swap the 1000-way ImageNet head for a single-logit head (binary
# classification, trained with BCEWithLogitsLoss further down).
#
# The input width is read from the layer being replaced rather than
# hard-coded. The new layer is created on the CPU by default, so it is moved
# to `device` explicitly to match the rest of the model, instead of relying
# on a later call to move it.
head_in_features = model.classifier[3].in_features
model.classifier[3] = nn.Linear(head_in_features, 1).to(device)

# Summary of the final architecture: only the new head should be trainable.
summary(
    model=model,
    input_size=(1, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)
============================================================================================================================================ Layer (type (var_name)) Input Shape Output Shape Param # Trainable ============================================================================================================================================ MobileNetV3 (MobileNetV3) [1, 3, 224, 224] [1, 1] -- Partial ├─Sequential (features) [1, 3, 224, 224] [1, 960, 7, 7] -- False │ └─Conv2dNormActivation (0) [1, 3, 224, 224] [1, 16, 112, 112] -- False │ │ └─Conv2d (0) [1, 3, 224, 224] [1, 16, 112, 112] (432) False │ │ └─BatchNorm2d (1) [1, 16, 112, 112] [1, 16, 112, 112] (32) False │ │ └─Hardswish (2) [1, 16, 112, 112] [1, 16, 112, 112] -- -- │ └─InvertedResidual (1) [1, 16, 112, 112] [1, 16, 112, 112] -- False │ │ └─Sequential (block) [1, 16, 112, 112] [1, 16, 112, 112] (464) False │ └─InvertedResidual (2) [1, 16, 112, 112] [1, 24, 56, 56] -- False │ │ └─Sequential (block) [1, 16, 112, 112] [1, 24, 56, 56] (3,440) False │ └─InvertedResidual (3) [1, 24, 56, 56] [1, 24, 56, 56] -- False │ │ └─Sequential (block) [1, 24, 56, 56] [1, 24, 56, 56] (4,440) False │ └─InvertedResidual (4) [1, 24, 56, 56] [1, 40, 28, 28] -- False │ │ └─Sequential (block) [1, 24, 56, 56] [1, 40, 28, 28] (10,328) False │ └─InvertedResidual (5) [1, 40, 28, 28] [1, 40, 28, 28] -- False │ │ └─Sequential (block) [1, 40, 28, 28] [1, 40, 28, 28] (20,992) False │ └─InvertedResidual (6) [1, 40, 28, 28] [1, 40, 28, 28] -- False │ │ └─Sequential (block) [1, 40, 28, 28] [1, 40, 28, 28] (20,992) False │ └─InvertedResidual (7) [1, 40, 28, 28] [1, 80, 14, 14] -- False │ │ └─Sequential (block) [1, 40, 28, 28] [1, 80, 14, 14] (32,080) False │ └─InvertedResidual (8) [1, 80, 14, 14] [1, 80, 14, 14] -- False │ │ └─Sequential (block) [1, 80, 14, 14] [1, 80, 14, 14] (34,760) False │ └─InvertedResidual (9) [1, 80, 14, 14] [1, 80, 14, 14] -- False │ │ └─Sequential (block) [1, 80, 14, 14] [1, 80, 14, 14] (31,992) False │ 
└─InvertedResidual (10) [1, 80, 14, 14] [1, 80, 14, 14] -- False │ │ └─Sequential (block) [1, 80, 14, 14] [1, 80, 14, 14] (31,992) False │ └─InvertedResidual (11) [1, 80, 14, 14] [1, 112, 14, 14] -- False │ │ └─Sequential (block) [1, 80, 14, 14] [1, 112, 14, 14] (214,424) False │ └─InvertedResidual (12) [1, 112, 14, 14] [1, 112, 14, 14] -- False │ │ └─Sequential (block) [1, 112, 14, 14] [1, 112, 14, 14] (386,120) False │ └─InvertedResidual (13) [1, 112, 14, 14] [1, 160, 7, 7] -- False │ │ └─Sequential (block) [1, 112, 14, 14] [1, 160, 7, 7] (429,224) False │ └─InvertedResidual (14) [1, 160, 7, 7] [1, 160, 7, 7] -- False │ │ └─Sequential (block) [1, 160, 7, 7] [1, 160, 7, 7] (797,360) False │ └─InvertedResidual (15) [1, 160, 7, 7] [1, 160, 7, 7] -- False │ │ └─Sequential (block) [1, 160, 7, 7] [1, 160, 7, 7] (797,360) False │ └─Conv2dNormActivation (16) [1, 160, 7, 7] [1, 960, 7, 7] -- False │ │ └─Conv2d (0) [1, 160, 7, 7] [1, 960, 7, 7] (153,600) False │ │ └─BatchNorm2d (1) [1, 960, 7, 7] [1, 960, 7, 7] (1,920) False │ │ └─Hardswish (2) [1, 960, 7, 7] [1, 960, 7, 7] -- -- ├─AdaptiveAvgPool2d (avgpool) [1, 960, 7, 7] [1, 960, 1, 1] -- -- ├─Sequential (classifier) [1, 960] [1, 1] -- Partial │ └─Linear (0) [1, 960] [1, 1280] (1,230,080) False │ └─Hardswish (1) [1, 1280] [1, 1280] -- -- │ └─Dropout (2) [1, 1280] [1, 1280] -- -- │ └─Linear (3) [1, 1280] [1, 1] 1,281 True ============================================================================================================================================ Total params: 4,203,313 Trainable params: 1,281 Non-trainable params: 4,202,032 Total mult-adds (Units.MEGABYTES): 215.34 ============================================================================================================================================ Input size (MB): 0.60 Forward/backward pass size (MB): 70.45 Params size (MB): 16.81 Estimated Total Size (MB): 87.86 
============================================================================================================================================
# Binary cross-entropy computed on raw logits, matching the model's single
# output unit.
loss_fn = nn.BCEWithLogitsLoss()

# Adam over all model parameters with a learning rate of 1e-3.
# (An SGD variant with lr=0.1, momentum=0.1 was tried previously.)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
from tqdm.auto import tqdm
import time

EPOCHS = 30

# Per-epoch history, consumed by the plotting and printing cells below.
train_losses, test_losses, test_accuracies = [], [], []

# Run EPOCHS rounds of training followed by evaluation, timing the whole run.
# `train_loop` / `test_loop` come from utils.train_test (not shown here).
program_starts = time.time()
for epoch in tqdm(range(EPOCHS)):
    train_loss = train_loop(train_dataloader, model, loss_fn, optimizer, device)
    test_loss, test_accuracy = test_loop(test_dataloader, model, loss_fn, device)

    train_losses.append(train_loss)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)

    print(f"Epoch: {epoch+1} | Training Loss: {train_loss} | Test Loss: {test_loss} | Test acc: {test_accuracy}%")
program_ends = time.time()

print(f"Elapsed: {program_ends-program_starts}s")
0%| | 0/30 [00:00<?, ?it/s]
Epoch: 1 | Training Loss: 0.6668773889541626 | Test Loss: 0.6969553828239441 | Test acc: 49.01960784313726% Epoch: 2 | Training Loss: 0.5958903431892395 | Test Loss: 0.6708950400352478 | Test acc: 50.98039215686274% Epoch: 3 | Training Loss: 0.5542582869529724 | Test Loss: 0.6492452621459961 | Test acc: 58.8235294117647% Epoch: 4 | Training Loss: 0.5163666605949402 | Test Loss: 0.6294794678688049 | Test acc: 60.78431372549019% Epoch: 5 | Training Loss: 0.488847017288208 | Test Loss: 0.6105336546897888 | Test acc: 62.745098039215684% Epoch: 6 | Training Loss: 0.4653830826282501 | Test Loss: 0.5923246145248413 | Test acc: 62.745098039215684% Epoch: 7 | Training Loss: 0.4295113980770111 | Test Loss: 0.5752974152565002 | Test acc: 64.70588235294117% Epoch: 8 | Training Loss: 0.4152264893054962 | Test Loss: 0.5597549676895142 | Test acc: 66.66666666666667% Epoch: 9 | Training Loss: 0.39757809042930603 | Test Loss: 0.5442760586738586 | Test acc: 66.66666666666667% Epoch: 10 | Training Loss: 0.38699790835380554 | Test Loss: 0.530937671661377 | Test acc: 72.54901960784314% Epoch: 11 | Training Loss: 0.366444855928421 | Test Loss: 0.5186097025871277 | Test acc: 72.54901960784314% Epoch: 12 | Training Loss: 0.35914748907089233 | Test Loss: 0.5050334930419922 | Test acc: 72.54901960784314% Epoch: 13 | Training Loss: 0.34462764859199524 | Test Loss: 0.4937037527561188 | Test acc: 70.58823529411765% Epoch: 14 | Training Loss: 0.33720025420188904 | Test Loss: 0.4853236675262451 | Test acc: 74.50980392156863% Epoch: 15 | Training Loss: 0.3367728590965271 | Test Loss: 0.4772041440010071 | Test acc: 76.47058823529412% Epoch: 16 | Training Loss: 0.31791597604751587 | Test Loss: 0.4677819013595581 | Test acc: 78.43137254901961% Epoch: 17 | Training Loss: 0.31470802426338196 | Test Loss: 0.45909610390663147 | Test acc: 76.47058823529412% Epoch: 18 | Training Loss: 0.3018851578235626 | Test Loss: 0.4508737325668335 | Test acc: 78.43137254901961% Epoch: 19 | Training Loss: 
0.3013880252838135 | Test Loss: 0.4454551637172699 | Test acc: 78.43137254901961% Epoch: 20 | Training Loss: 0.2897450625896454 | Test Loss: 0.4404674470424652 | Test acc: 82.3529411764706% Epoch: 21 | Training Loss: 0.2797577679157257 | Test Loss: 0.4358077943325043 | Test acc: 80.3921568627451% Epoch: 22 | Training Loss: 0.27962398529052734 | Test Loss: 0.43255695700645447 | Test acc: 80.3921568627451% Epoch: 23 | Training Loss: 0.2851901054382324 | Test Loss: 0.42878907918930054 | Test acc: 78.43137254901961% Epoch: 24 | Training Loss: 0.2639926075935364 | Test Loss: 0.4241710901260376 | Test acc: 84.31372549019608% Epoch: 25 | Training Loss: 0.25456246733665466 | Test Loss: 0.4191741645336151 | Test acc: 84.31372549019608% Epoch: 26 | Training Loss: 0.23348985612392426 | Test Loss: 0.4147363305091858 | Test acc: 88.23529411764706% Epoch: 27 | Training Loss: 0.2518940269947052 | Test Loss: 0.41209539771080017 | Test acc: 88.23529411764706% Epoch: 28 | Training Loss: 0.23118329048156738 | Test Loss: 0.4108111262321472 | Test acc: 88.23529411764706% Epoch: 29 | Training Loss: 0.2713697552680969 | Test Loss: 0.40884706377983093 | Test acc: 88.23529411764706% Epoch: 30 | Training Loss: 0.252165287733078 | Test Loss: 0.4056894779205322 | Test acc: 88.23529411764706% Elapsed: 627.1224312782288s
# Figure 1: test and train loss per epoch.
for history, curve_label in ((test_losses, "test"), (train_losses, "train")):
    plt.plot(history, label=curve_label)
plt.legend()
plt.show()

# Figure 2: test accuracy per epoch (train accuracy is not plotted).
plt.plot(test_accuracies, label="test")
plt.legend()
plt.show()
# Show a 4x4 grid of random test samples with true vs. predicted class.
# Titles are green when the prediction is correct and red otherwise.
#
# Differences from a bare forward pass:
#  * the model is switched to eval mode so dropout is disabled, and its
#    previous mode is restored afterwards;
#  * the forward pass runs under inference_mode, so no autograd graph is built;
#  * the prediction is converted to a plain int before indexing `classes`;
#  * images are de-normalized and clamped to [0, 1] so imshow does not clip
#    them (`cmap` is dropped because matplotlib ignores it for RGB input).
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

was_training = model.training
model.eval()

fig = plt.figure(figsize=(9, 9))
rows, cols = 4, 4
with torch.inference_mode():
    for i in range(1, rows * cols + 1):
        random_idx = torch.randint(0, len(test_dataset), size=[1]).item()
        image, true_label = test_dataset[random_idx]

        # Single-image batch -> one logit; logit > 0 is sigmoid(logit) > 0.5.
        logit = model(image.unsqueeze(dim=0).to(device))
        predicted_label = int((logit > 0).item())

        fig.add_subplot(rows, cols, i)
        display_image = (image * norm_std + norm_mean).clamp(0, 1).permute(1, 2, 0)
        plt.imshow(display_image)

        color = "g" if true_label == predicted_label else "r"
        plt.title(f"True: {classes[true_label]} | Pred: {classes[predicted_label]}", fontsize=8, c=color)
        plt.axis(False)

# Leave the model in the mode it was in before this cell ran.
model.train(was_training)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
try:
from torchmetrics import ConfusionMatrix
except:
!pip install torchmetrics
from torchmetrics import ConfusionMatrix
try:
from mlxtend.plotting import plot_confusion_matrix
except:
!pip install mlxtend
from mlxtend.plotting import plot_confusion_matrix
# Predict every test sample once and plot the normalized confusion matrix.
#
# Inference runs in eval mode (dropout disabled) and under inference_mode
# (no autograd graph). The previous train/eval mode is restored afterwards.
# Predictions are collected as plain Python ints, so the final tensors are
# built from scalars rather than from a list of device tensors.
was_training = model.training
model.eval()

y_preds = []
true_labels = []
with torch.inference_mode():
    for i in tqdm(range(len(test_dataset))):
        image, true_label = test_dataset[i]
        # Single-image batch -> one logit; logit > 0 is sigmoid(logit) > 0.5.
        logit = model(image.unsqueeze(dim=0).to(device))
        y_preds.append(int((logit > 0).item()))
        true_labels.append(true_label)

model.train(was_training)

y_preds = torch.tensor(y_preds)
true_labels = torch.tensor(true_labels)

# The two classes are treated as a 2-class multiclass problem, giving a 2x2
# matrix of counts.
confmat = ConfusionMatrix(task="multiclass", num_classes=2)
matrix = np.array(confmat(y_preds, true_labels))

fig, ax = plot_confusion_matrix(conf_mat=matrix, colorbar=True, show_absolute=False, show_normed=True, class_names=classes)
plt.show()
0%| | 0/51 [00:00<?, ?it/s]
# Dump the raw training history and the confusion matrix as plain Python
# values so they can be copied out of the notebook.
for loss_history in (train_losses, test_losses):
    print([loss.item() for loss in loss_history])
print(test_accuracies)
print([list(map(int, row)) for row in matrix])
[0.6668773889541626, 0.5958903431892395, 0.5542582869529724, 0.5163666605949402, 0.488847017288208, 0.4653830826282501, 0.4295113980770111, 0.4152264893054962, 0.39757809042930603, 0.38699790835380554, 0.366444855928421, 0.35914748907089233, 0.34462764859199524, 0.33720025420188904, 0.3367728590965271, 0.31791597604751587, 0.31470802426338196, 0.3018851578235626, 0.3013880252838135, 0.2897450625896454, 0.2797577679157257, 0.27962398529052734, 0.2851901054382324, 0.2639926075935364, 0.25456246733665466, 0.23348985612392426, 0.2518940269947052, 0.23118329048156738, 0.2713697552680969, 0.252165287733078] [0.6969553828239441, 0.6708950400352478, 0.6492452621459961, 0.6294794678688049, 0.6105336546897888, 0.5923246145248413, 0.5752974152565002, 0.5597549676895142, 0.5442760586738586, 0.530937671661377, 0.5186097025871277, 0.5050334930419922, 0.4937037527561188, 0.4853236675262451, 0.4772041440010071, 0.4677819013595581, 0.45909610390663147, 0.4508737325668335, 0.4454551637172699, 0.4404674470424652, 0.4358077943325043, 0.43255695700645447, 0.42878907918930054, 0.4241710901260376, 0.4191741645336151, 0.4147363305091858, 0.41209539771080017, 0.4108111262321472, 0.40884706377983093, 0.4056894779205322] [49.01960784313726, 50.98039215686274, 58.8235294117647, 60.78431372549019, 62.745098039215684, 62.745098039215684, 64.70588235294117, 66.66666666666667, 66.66666666666667, 72.54901960784314, 72.54901960784314, 72.54901960784314, 70.58823529411765, 74.50980392156863, 76.47058823529412, 78.43137254901961, 76.47058823529412, 78.43137254901961, 78.43137254901961, 82.3529411764706, 80.3921568627451, 80.3921568627451, 78.43137254901961, 84.31372549019608, 84.31372549019608, 88.23529411764706, 88.23529411764706, 88.23529411764706, 88.23529411764706, 88.23529411764706] [[23, 4], [2, 22]]