Breaking the ice post - let's start blogging about medical imaging :)
To start off the blog, I've chosen the most basic example I could come up with:
Medical image categorization based on comparing a set of test images against the "statistically average" image of each category.
Later posts will build on this with more advanced techniques, so stay tuned!
But first, let's download the data:
! rm -rf ./medical_mnist
! git clone https://github.com/apolanco3225/Medical-MNIST-Classification.git
! mv Medical-MNIST-Classification/resized/ ./medical_mnist
! rm -rf Medical-MNIST-Classification
Cloning into 'Medical-MNIST-Classification'...
remote: Enumerating objects: 58532, done.
remote: Total 58532 (delta 0), reused 0 (delta 0), pack-reused 58532
Receiving objects: 100% (58532/58532), 77.86 MiB | 4.39 MiB/s, done.
Resolving deltas: 100% (506/506), done.
Checking connectivity... done.
Checking out files: 100% (58959/58959), done.
Install useful libraries:
## run this if you don't have PyTorch and torchvision installed
# !pip install torch torchvision
The data is now in the medical_mnist folder:
from pathlib import Path
data = Path('medical_mnist')
list(data.iterdir())
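Before picking categories, it can help to check how many images each folder holds. Here is a minimal sketch that assumes every subfolder of medical_mnist contains only image files. count_images is a hypothetical helper name, not part of any library:

```python
from pathlib import Path

def count_images(root):
    """Return {category_name: number_of_files} for each subfolder of root, sorted by name."""
    root = Path(root)
    return {
        d.name: sum(1 for p in d.iterdir() if p.is_file())
        for d in sorted(root.iterdir())
        if d.is_dir()
    }

# usage, assuming the data was downloaded as above:
# print(count_images('medical_mnist'))
```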
[PosixPath('medical_mnist/AbdomenCT'), PosixPath('medical_mnist/BreastMRI'), PosixPath('medical_mnist/CXR'), PosixPath('medical_mnist/ChestCT'), PosixPath('medical_mnist/Hand'), PosixPath('medical_mnist/HeadCT')]
Let's see what we have here. Since this is the most basic technique, we'll pick the two categories whose images look the most different from each other.
import matplotlib.pyplot as plt
from PIL import Image
for d in data.iterdir():
    print(d)
    plt.imshow(Image.open(list(d.iterdir())[0]))
    plt.show()
medical_mnist/AbdomenCT
medical_mnist/BreastMRI
medical_mnist/CXR
medical_mnist/ChestCT
medical_mnist/Hand
medical_mnist/HeadCT
Load the data into tensors:
import torch
from torchvision.transforms import ToTensor
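# note: ToTensor() already scales 8-bit images to [0, 1], so the extra /255 leaves pixel values in [0, 1/255];
# this shrinks every distance below by the same factor and does not change which category is closer
stacked_cxrs = torch.stack([ToTensor()(Image.open(path)).float()/255 for path in (data/'CXR').iterdir()])
stacked_heads = torch.stack([ToTensor()(Image.open(path)).float()/255 for path in (data/'HeadCT').iterdir()])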
As a good practice, let's look at the first image to check that the data loaded correctly:
plt.imshow(stacked_cxrs[0][0])
Now let's build an "ideal" image for each category. This ideal image is simply the per-pixel mean across all the images in that category:
mean_cxrs = stacked_cxrs.mean(0)
plt.imshow(mean_cxrs[0])
mean_headct = stacked_heads.mean(0)
plt.imshow(mean_headct[0])
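To make "per-pixel mean" concrete, here is a sketch on three tiny synthetic 1x2x2 tensors standing in for real images. torch.stack puts the image index first, so .mean(0) averages each pixel position across the images:

```python
import torch

# three tiny single-channel 2x2 "images", shape (channel, height, width)
imgs = [torch.tensor([[[0., 2.], [4., 6.]]]),
        torch.tensor([[[2., 2.], [4., 0.]]]),
        torch.tensor([[[4., 2.], [4., 3.]]])]

stacked = torch.stack(imgs)   # shape (3, 1, 2, 2): image index is dimension 0
ideal = stacked.mean(0)       # average over dimension 0, i.e. pixel by pixel across images
print(ideal)                  # values [[[2, 2], [4, 3]]]
```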
Now we can measure how much an example image differs from each of the ideals:
import torch.nn.functional as F
F.mse_loss(stacked_cxrs[0], mean_cxrs).sqrt()
tensor(0.0010)
F.mse_loss(stacked_cxrs[0], mean_headct).sqrt()
tensor(0.0016)
Looks like that one was a CXR indeed: its root-mean-squared (L2) distance to the ideal CXR image (mean_cxrs) is lower than its distance to the ideal HeadCT image (mean_headct).
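For reference, F.mse_loss averages the squared pixel differences (its default reduction is 'mean'), so taking .sqrt() gives the root-mean-squared distance over all N pixels of an image x and an ideal image m. Scaling both images by the same constant c scales this distance by |c|. This is why the very small absolute values do not change which ideal is closer:

$$
\mathrm{RMSE}(x, m) = \sqrt{\frac{1}{N}\sum_{i=1}^{N} (x_i - m_i)^2},
\qquad
\mathrm{RMSE}(c\,x,\, c\,m) = |c|\,\mathrm{RMSE}(x, m)
$$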
So let's build a simple classifier function that predicts whether an image is a HeadCT, i.e. whether it is closer to the ideal HeadCT image than to the ideal CXR image:
def is_headct(img_tensor):
    if F.mse_loss(img_tensor, mean_cxrs) > F.mse_loss(img_tensor, mean_headct):
        return True
    else:
        return False
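The same idea extends beyond two classes. You compute one mean image per category and predict the category whose mean is closest, which is a nearest-mean (nearest-centroid) classifier. Here is a minimal sketch with hypothetical helper names fit_means and predict, using small synthetic tensors in place of the real images. With the real data, the dictionary could hold one stack per folder returned by data.iterdir():

```python
import torch
import torch.nn.functional as F

def fit_means(stacks):
    """Map {category: tensor (n_images, C, H, W)} to {category: per-pixel mean image (C, H, W)}."""
    return {cat: imgs.mean(0) for cat, imgs in stacks.items()}

def predict(img, means):
    """Return the category whose mean image has the lowest MSE to img."""
    return min(means, key=lambda cat: F.mse_loss(img, means[cat]).item())

# synthetic stand-ins for two categories: "dark" images around 0.1, "bright" ones around 0.9
torch.manual_seed(0)
stacks = {
    'dark': 0.1 + 0.05 * torch.randn(10, 1, 4, 4),
    'bright': 0.9 + 0.05 * torch.randn(10, 1, 4, 4),
}
means = fit_means(stacks)
print(predict(torch.full((1, 4, 4), 0.2), means))  # dark
```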
Now we test the classifier:
cxrs_preds = torch.tensor([not is_headct(stacked_cxrs[i]) for i in range(stacked_cxrs.shape[0])])
cxrs_accuracy = cxrs_preds.sum().float() / cxrs_preds.shape[0]
print(f'Accuracy on CXRs: {round( (cxrs_accuracy).item() * 100, 2)}%')
Accuracy on CXRs: 99.15%
head_preds = torch.tensor([is_headct(stacked_heads[i]) for i in range(stacked_heads.shape[0])])
head_accuracy = head_preds.sum().float() / head_preds.shape[0]
print(f'Accuracy on HeadCTs: {round( (head_accuracy).item() * 100, 2)}%')
Accuracy on HeadCTs: 100.0%
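The per-image loops above can also be written with broadcasting. Subtracting a (C, H, W) mean from an (n, C, H, W) stack and averaging over the pixel dimensions gives every image's MSE at once. In the sketch below, per_image_mse is a hypothetical helper and the tensors are synthetic. The comparison mirrors is_headct:

```python
import torch
import torch.nn.functional as F

def per_image_mse(stack, mean):
    """MSE between each image in stack (n, C, H, W) and mean (C, H, W); returns shape (n,)."""
    return ((stack - mean) ** 2).mean(dim=(1, 2, 3))

torch.manual_seed(0)
cxr_like = 0.2 + 0.05 * torch.rand(8, 1, 4, 4)    # synthetic stand-in for a CXR stack
head_like = 0.7 + 0.05 * torch.rand(8, 1, 4, 4)   # synthetic stand-in for a HeadCT stack
mean_a, mean_b = cxr_like.mean(0), head_like.mean(0)

# predicted "head" when the image is farther from the CXR-like mean than from the head-like mean
pred_head = per_image_mse(head_like, mean_a) > per_image_mse(head_like, mean_b)
accuracy = pred_head.float().mean().item()
print(f'Accuracy on head-like images: {accuracy * 100:.2f}%')
```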
... of course, we didn't split the data into train and test sets here, so the results are biased (every image we predict on was also used to compute the "ideal" image). But this was a nice start for the blog :)
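A fairer evaluation holds out part of each category. The mean image is computed from a training split only, and accuracy is measured on the remaining images. The minimal sketch below assumes a random 80/20 split. The ratio, the seed, and the helper names split and evaluate are arbitrary choices for illustration. In the notebook, the call would be evaluate(stacked_cxrs, stacked_heads), which returns (CXR accuracy, HeadCT accuracy):

```python
import torch
import torch.nn.functional as F

def split(stack, train_frac=0.8, seed=0):
    """Randomly split a stack of images (n, C, H, W) into train and test parts."""
    g = torch.Generator().manual_seed(seed)
    idx = torch.randperm(stack.shape[0], generator=g)
    n_train = int(train_frac * stack.shape[0])
    return stack[idx[:n_train]], stack[idx[n_train:]]

def evaluate(stack_a, stack_b, train_frac=0.8):
    """Fit one mean image per class on the train split; return test accuracy for each class."""
    a_train, a_test = split(stack_a, train_frac)
    b_train, b_test = split(stack_b, train_frac)
    mean_a, mean_b = a_train.mean(0), b_train.mean(0)

    def closer_to_a(img):
        return bool(F.mse_loss(img, mean_a) < F.mse_loss(img, mean_b))

    acc_a = sum(closer_to_a(x) for x in a_test) / len(a_test)
    acc_b = sum(not closer_to_a(x) for x in b_test) / len(b_test)
    return acc_a, acc_b

# usage with synthetic data standing in for the two real stacks
torch.manual_seed(0)
a = 0.2 + 0.05 * torch.rand(20, 1, 4, 4)
b = 0.7 + 0.05 * torch.rand(20, 1, 4, 4)
print(evaluate(a, b))
```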