Figure 5 in our post shows the DCGAN architecture. We'll need to implement our discriminator, generator, data loading, and training code. I usually prefer to start with the data loading piece - after all, without data we can't do much! We'll start by writing a simple script to pull the data from the MVTec website.
!pip install jupyter
!pip install matplotlib
!pip install fastai
!pip install wget
!pip install kornia
!pip install opencv-python
import sys, wget, tarfile, os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
def simple_progress_bar(current, total, width=80):
    progress_message = "Downloading: %d%% [%d / %d] bytes" % (current / total * 100, current, total)
    sys.stdout.write("\r" + progress_message)
    sys.stdout.flush()
def get_mvtech_dataset(data_dir, dataset_name):
    data_dir.mkdir(exist_ok=True)
    if not (data_dir/('%s.tar.xz'%dataset_name)).exists():
        wget.download('ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/%s.tar.xz'%dataset_name,
                      out=str(data_dir/('%s.tar.xz'%dataset_name)), bar=simple_progress_bar)
    if not (data_dir/dataset_name).exists():
        tar = tarfile.open(data_dir/('%s.tar.xz'%dataset_name))
        tar.extractall(data_dir)
        tar.close()
Now, which product class should we experiment with? For some reason, I just really enjoy the aesthetics of the hazelnut class - these images have a cool retro feel to me. This of course has nothing to do with how our model will perform - so please feel free to experiment with different classes!
data_path=Path('data')
dset='hazelnut'
get_mvtech_dataset(data_path, dset)
Downloading: 90% [558080000 / 617098680] bytes
After our script runs, we can have a look at the structure of our data.
list((data_path/dset).glob('*'))
[PosixPath('data/hazelnut/test'), PosixPath('data/hazelnut/train'), PosixPath('data/hazelnut/readme.txt'), PosixPath('data/hazelnut/license.txt'), PosixPath('data/hazelnut/ground_truth')]
We're provided a train, test, and ground_truth (labels) folder, and within train it looks like we're given 391 examples of non-defective hazelnuts:
list((data_path/dset/'train').glob('*'))
[PosixPath('data/hazelnut/train/good')]
im_paths=list((data_path/dset/'train'/'good').glob('*'))
im_paths[:5] #Look at first 5 paths
[PosixPath('data/hazelnut/train/good/177.png'), PosixPath('data/hazelnut/train/good/136.png'), PosixPath('data/hazelnut/train/good/079.png'), PosixPath('data/hazelnut/train/good/234.png'), PosixPath('data/hazelnut/train/good/357.png')]
len(im_paths) #How many examples do we have?
391
plt.imshow(plt.imread(str(im_paths[0])))
plt.title('One sexy hazelnut.');
Our test folder has a bit more going on: it includes examples of 4 defect classes, plus some more good examples to use in testing. We'll set aside our test set for now, and come back to it after training our GAN.
list((data_path/dset/'test').glob('*'))
[PosixPath('data/hazelnut/test/print'), PosixPath('data/hazelnut/test/crack'), PosixPath('data/hazelnut/test/hole'), PosixPath('data/hazelnut/test/good'), PosixPath('data/hazelnut/test/cut')]
We'll be using a few Python libraries of note here: PyTorch, kornia, and fastai. There are of course other fantastic tools out there like TensorFlow and Keras - PyTorch and fastai have been my go-to for the last 2 years or so, especially for getting stuff up and running quickly.
import kornia
from fastai.vision import *
# Might need/want to suppress warnings if your fastai and pytorch versions don't quite agree
import warnings
warnings.filterwarnings('ignore')
One thing I particularly like about fastai is the built-in dataloader classes, called DataBunches. Let's create one for our training set. We'll create a dataloader that returns minibatches of size 64 and downsamples our images to 64x64 (the resolution used in the DCGAN paper). Our databunch will also take care of data augmentation and normalization.
batch_size, im_size, channels = 64, 64, 3
tfms = ([*rand_pad(padding=3, size=im_size, mode='border')], [])
data = ImageList.from_folder(data_path/dset/'train'/'good').split_none() \
.label_empty() \
.transform(tfms, size=im_size) \
.databunch(bs=batch_size) \
.normalize((0.5, 0.5))
data.show_batch(figsize=(8, 8))
Let's have a quick look at the scale of our data.
x, y=data.one_batch()
plt.hist(x.numpy().ravel(),100); plt.grid(1)
One thing I love about fastai is that it's a really light wrapper on PyTorch. There's no fastai implementation of DCGAN, so we have to build it ourselves. Happily, we can do this in PyTorch and still take advantage of the fastai dataloaders and training code. Before we dive in, let's figure out if we're going to be using a GPU or CPU for training. I highly recommend training this on a GPU - a CPU is fine for experimentation, but will take quite some time to train the full DCGAN model.
import torch
import torch.nn as nn
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') #Do we have a GPU?
defaults.device = device
print(device)
cuda
Let's start with setting up our generator. Looking at figure X, we see that our generator is required to upsample between layers - this is generally achieved in Deep Learning models using transposed convolutional layers, also known as fractionally strided convolutional layers. There's a terrific paper on this by Vincent Dumoulin and Francesco Visin.
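As a quick sanity check on the upsampling arithmetic: without output padding, the spatial output size of a transposed convolution is `(in - 1) * stride - 2 * padding + kernel_size`. Here's a small sketch (pure Python, no framework needed) tracing a 1x1 latent "image" through the layer settings we'll use in the generator:

```python
def conv_transpose_out_size(size_in, kernel_size=4, stride=2, padding=1):
    """Spatial output size of a transposed convolution (no output_padding)."""
    return (size_in - 1) * stride - 2 * padding + kernel_size

size = 1          # the latent vector enters as a 1x1 spatial "image"
sizes = [size]
# First layer uses ks=4, stride=1, padding=0; the remaining layers use the defaults.
size = conv_transpose_out_size(size, kernel_size=4, stride=1, padding=0)
sizes.append(size)
for _ in range(4):  # four more ks=4, stride=2, padding=1 layers
    size = conv_transpose_out_size(size)
    sizes.append(size)

print(sizes)  # [1, 4, 8, 16, 32, 64]
```

Each stride-2 layer doubles the spatial resolution, which is exactly how we get from a 1x1 latent point to a 64x64 image in five layers.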
Note that in section 3 of the DCGAN paper, the authors call for batch normalization in the generator and discriminator, and for ReLU activation functions in every layer of the generator except the output layer, which uses tanh. We can make our code a bit more succinct by creating a conv_trans method that we can use for each layer.
def conv_trans(ni, nf, ks=4, stride=2, padding=1):
    return nn.Sequential(
        nn.ConvTranspose2d(ni, nf, kernel_size=ks, bias=False, stride=stride, padding=padding),
        nn.BatchNorm2d(nf),
        nn.ReLU(inplace=True))
Now we can put these layers together into our generator, following the number of filters shown in figure X.
G = nn.Sequential(
    conv_trans(100, 1024, ks=4, stride=1, padding=0),
    conv_trans(1024, 512),
    conv_trans(512, 256),
    conv_trans(256, 128),
    nn.ConvTranspose2d(128, channels, 4, stride=2, padding=1),
    nn.Tanh()).to(device)
Alright, so we have a generator! Now, if you were paying attention in section 3.2, you may remember that the input to our generator during training is just random noise vectors. And from these random noise vectors, our generator is supposed to magically produce realistic-looking images. OK, so let's create a random vector:
z = torch.randn(1, 100, 1, 1)
And pass it into our Generator and see what we get!
fake = G(z.to(device))
plt.imshow(fake[0, 0].cpu().detach().numpy()); plt.grid(0)
Now, as you can see, our output really does just look like noise! This is of course because we haven't trained our GAN yet - when we're done, this output should look like a hazelnut!
Now let's set up our discriminator. We'll use a similar pattern of creating a subfunction that contains our convolution, batch normalization, and ReLU. Note here that the DCGAN authors call for Leaky ReLU instead of ReLU.
def conv(ni, nf, ks=4, stride=2, padding=1):
    return nn.Sequential(
        nn.Conv2d(ni, nf, kernel_size=ks, bias=False, stride=stride, padding=padding),
        nn.BatchNorm2d(nf),
        nn.LeakyReLU(0.2, inplace=True))
D = nn.Sequential(
    conv(channels, 128),
    conv(128, 256),
    conv(256, 512),
    conv(512, 1024),
    nn.Conv2d(1024, 1, 4, stride=1, padding=0),
    Flatten(),
    nn.Sigmoid()).to(device)
Just as a quick sanity check, let's play with our discriminator and generator for a minute. We know that we can pass in random noise vectors into our generator and get crappy fake images out. Now, we should also be able to take these fake images and pass them into our discriminator. Let's try it!
fake=G(z.to(device))
fake.shape
torch.Size([1, 3, 64, 64])
D(fake)
tensor([[0.5238]], device='cuda:0', grad_fn=<SigmoidBackward>)
Just to test our thinking here - what is the meaning of this output value? Well, we created a fake image by passing noise into our generator, and have now passed that fake image into our discriminator, which returns what it believes to be the probability that the image is real. Finally, just one more sanity check. As we train, we'll be passing both real and fake data into our discriminator. We'll be getting our data from our fastai data loader:
x,y=data.one_batch()
x.shape
torch.Size([64, 3, 64, 64])
out=D(x.to(device))
plt.plot(out.detach().cpu()); plt.grid(1)
So, what are we looking at here? Well, notice that the dimension of our minibatch is [64, 3, 64, 64], meaning we have 64 images we're analyzing at once in our discriminator. The outputs we've plotted are the probabilities of being real that the discriminator has assigned to each image. As we can see, our results are all over the place - again, this is because we haven't trained anything yet. Once we're done, an effective discriminator should assign a probability close to one to each real image.
Alright, now that we have our data loader, discriminator, and generator set up, we can train our model! It's really helpful to have some visualization as we train, especially to see if the fake images the generator is creating look convincing. Let's start by creating a performance visualization method to show progress as we train. We'll keep track of a few key visuals while training. First, we'll choose a z_fixed - a batch of randomly chosen and static points in our latent space. At each visualization step, we'll pass these points through our generator and see how our fake images look. As we train, our random noise should start to be shaped into hazelnuts! Secondly, we'll plot a histogram of the pixel intensities of our fake images G(z) and compare it to a histogram of the pixel intensity values in our real images x. As we train, these distributions should look more and more similar. Finally, we'll also visualize our generator and discriminator loss functions as we train - we should hopefully see a healthy back and forth; if either model consistently wins, it's unlikely our fake images will look anything like real ones!
from torch import optim
from tqdm import tqdm
from IPython import display
import matplotlib.gridspec as gridspec
save_training_viz=True
save_dir=Path('data/exports') #Location to save training visualizations
(save_dir/'viz').mkdir(exist_ok=True, parents=True)
(save_dir/'ckpts').mkdir(exist_ok=True, parents=True)
def show_progress(save=False):
    '''Visualization method to see how we're doing'''
    plt.clf(); fig=plt.figure(0, (24, 12)); gs=gridspec.GridSpec(6, 12)
    with torch.no_grad(): fake=G(z_fixed)
    for j in range(30):
        fig.add_subplot(gs[(j//6), j%6])
        plt.imshow((kornia.tensor_to_image(fake[j])+1)/2); plt.axis('off')
    ax=fig.add_subplot(gs[5, :4]); plt.hist(fake.detach().cpu().numpy().ravel(), 100, facecolor='xkcd:crimson')
    ax.get_yaxis().set_ticks([]); plt.xlabel('$G(z)$', fontsize=16); plt.xlim([-1, 1])
    ax=fig.add_subplot(gs[5, 4:7]); plt.hist(x.cpu().numpy().ravel(), 100, facecolor='xkcd:purple')
    ax.get_yaxis().set_ticks([]); plt.xlabel('$x$', fontsize=16)
    fig.add_subplot(gs[:, 7:])
    plt.plot(losses[0], color='xkcd:goldenrod', linewidth=2); plt.plot(losses[1], color='xkcd:sea blue', linewidth=2)
    plt.legend(['Discriminator', 'Generator'], loc=1, fontsize=16)
    plt.grid(1); plt.title('Epoch = ' + str(epoch), fontsize=16); plt.ylabel('loss', fontsize=16); plt.xlabel('iteration', fontsize=16)
    display.clear_output(wait=True); display.display(plt.gcf())
    if save: plt.savefig(save_dir/'viz'/(str(count)+'.png'), dpi=150)
Now we'll set up our loss function and optimizers following the DCGAN paper:
optD = optim.Adam(D.parameters(), lr=1e-4, betas = (0.5, 0.999))
optG = optim.Adam(G.parameters(), lr=1e-4, betas = (0.5, 0.999))
criterion = nn.BCELoss()
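Before training, it's worth pinning down what binary cross-entropy computes for each player. Here's the arithmetic worked by hand in pure Python - the probabilities below are made-up illustrative numbers, not outputs from a trained model:

```python
import math

def bce(p, label):
    """Binary cross-entropy for a single predicted probability p and target label (0 or 1)."""
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

# Discriminator: wants D(x) -> 1 for real images and D(G(z)) -> 0 for fakes.
d_real = bce(0.9, 1)  # confidently calls a real image real: low loss
d_fake = bce(0.1, 0)  # confidently calls a fake image fake: low loss

# Generator: scored against the *ones* labels - it wants D(G(z)) -> 1.
g_loss = bce(0.1, 1)  # discriminator spots the fake: high generator loss

print(round(d_real, 4), round(d_fake, 4), round(g_loss, 4))  # 0.1054 0.1054 2.3026
```

This is why the training loop below sums `l_fake` and `l_real` for the discriminator but uses the ones labels for the generator: the two networks are optimizing the same criterion with opposite targets.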
And finally we're ready to train! One tricky thing about GANs is that no one really knows when you should stop training. One reasonable, but kind of annoying, approach is to monitor the appearance of the generated samples while training - that's what we'll do here, while taking periodic snapshots of our weights.
zero_labels = torch.zeros(batch_size).to(device)
ones_labels = torch.ones(batch_size).to(device)
losses = [[],[]]
epochs, viz_freq, save_freq, count = 10000, 100, 500, 0
z_fixed = torch.randn(batch_size, 100, 1, 1).to(device)
for epoch in range(epochs):
    for i, (x, y) in enumerate(tqdm(data.train_dl)):
        #Train Discriminator
        requires_grad(G, False)  #Speeds up training a smidge
        z = torch.randn(batch_size, 100, 1, 1).to(device)
        l_fake = criterion(D(G(z)).view(-1), zero_labels)
        l_real = criterion(D(x).view(-1), ones_labels)
        loss = l_fake + l_real
        loss.backward(); losses[0].append(loss.item())
        optD.step(); G.zero_grad(); D.zero_grad()
        #Train Generator
        requires_grad(G, True)
        z = torch.randn(batch_size, 100, 1, 1).to(device)
        loss = criterion(D(G(z)).view(-1), ones_labels)
        loss.backward(); losses[1].append(loss.item())
        optG.step(); G.zero_grad(); D.zero_grad()
        if i%viz_freq==0: show_progress(save_training_viz)
        count += 1
    if (epoch+1)%save_freq==0:
        torch.save(G, save_dir/'ckpts'/('G_epoch_'+str(epoch)+'.pth'))
        torch.save(D, save_dir/'ckpts'/('D_epoch_'+str(epoch)+'.pth'))
100%|██████████| 6/6 [00:09<00:00, 1.50s/it] 0%| | 0/6 [00:00<?, ?it/s]
And in case you don't want to sit around waiting, here's a sped-up training animation:
from IPython.display import HTML
HTML('<iframe width="1120" height="630" src="https://www.youtube.com/embed/33wVaI7NUPw?rel=0&controls=0&showinfo=0" frameborder="0" allowfullscreen></iframe>')
Alright, this blows my mind. The fact that we can do something really hard (make realistic fake images) is insane - it makes no sense. But there we have it: not perfect, but pretty nice-looking fake images of hazelnuts, shaped out of random noise! This rounds out the first, and most complex, step in implementing the AnoGAN approach. A quick note on GAN training before we move on - it's incredibly finicky. Over the course of this writing, I trained 10-15 GANs, and the reproducibility is laughable. This could be improved by using the same random seed each time. As we'll see, our GAN is the critical component that allows us to find anomalies, and in my experiments I found that some of my GANs did significantly better than others. My best results were with a batch size of 128 and 2500 epochs of training, even though this model didn't produce the most realistic fake images! I think this just goes to show that GANs are still very much in the research phase, and there's lots of exciting/interesting work to be done to make these models production grade.
Now that we can produce reasonably realistic fake images of hazelnuts - it's time to use our GAN to find anomalies!
##Load checkpoints if needed:
# My favorite checkpoint, 2500 Epochs, batch size=128
G=torch.load('weights/G.pth')
D=torch.load('weights/D.pth')
So we now have a generator capable of generating fake images of hazelnuts, and a discriminator that's doing its very best to distinguish real from fake hazelnut images. Now what? How do we use these to find anomalies?
This is where things get really, really interesting. Remember that comment earlier about the good images representing a manifold in a higher-dimensional latent space of all images? We're now going to exploit this phenomenon to find anomalies. The idea is this: given a new image x, we're going to compute the representation of the image in our GAN's latent space, and if our image is on the manifold of good images, we're going to call it good. And if it's too far from our manifold, we'll call it an anomaly. Now, if you've been paying attention, you may be thinking something like: "Stephen, our GAN can only create fake images from random noise vectors, we can't go the other way!". And to this I would say, you are completely correct. Neural nets don't generally go backwards :).
So what should we do then? Given a new image, we want to find its representation in latent space - that is, the 100-dimensional vector that corresponds to that image. Solving this exact problem turns out to be basically impossible, but we can do something close. Let's grab an image and discuss the idea used by the AnoGAN authors and others to achieve this feat.
import cv2
im_path=(data_path/dset/'train'/'good').ls()[1]
im=cv2.cvtColor(cv2.imread(str(im_path)), cv2.COLOR_BGR2RGB)
im=cv2.resize(im, (64,64))
plt.imshow(im)
<matplotlib.image.AxesImage at 0x7efecc34fbd0>
We would like to know how far from the manifold of good images in our generator's latent space this image lies. Unfortunately, we can't really compute this directly - but we can do something close: we can find the nearest point in our generator's latent space, and measure how far away this image is from that point, giving us an idea of how well this image "fits" with our good images. We can find this nearest-neighbor point using an approach very similar to our method for training our network. Specifically, we'll pick a random point in latent space and walk through latent space from that point, guided by gradient descent. Our loss function is relatively simple: just the L1 loss between our generated image and our "query" image. This loss function is captured in equation 3 of the AnoGAN paper:
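Reconstructed here to match the $z_\gamma$ notation of equation 5 later in the post, the residual loss is the elementwise L1 distance between the query image $x$ and the generated image:

$$ \mathcal{L}_R(z_\gamma) = \sum \left| x - G(z_\gamma) \right| $$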
Let's give this a shot. We'll pick a random starting point and train for 1000 iterations to reduce the L1 loss between our query image and generated image.
im_tensor=((kornia.image_to_tensor(im).float()/128)-1).to(device) #Scale image between -1 and +1
z=torch.randn(1, 100, 1, 1, requires_grad=True, device=device) #Random starting point in latent space
opt=optim.Adam([z], lr=2e-2)
losses=[]
for i in tqdm(range(1000)):
    fake=G(z)
    loss=torch.nn.L1Loss()(fake.squeeze(0), im_tensor)
    loss.backward(); opt.step()
    z.grad.zero_(); G.zero_grad()
    losses.append(loss.item())
100%|██████████| 1000/1000 [00:06<00:00, 147.07it/s]
plt.plot(losses); plt.grid(1)
fig=plt.figure(0, (8,4))
fig.add_subplot(1,2,1)
plt.imshow(im); plt.title('Real Image')
fig.add_subplot(1,2,2)
plt.imshow((kornia.tensor_to_image(G(z))+1)/2); plt.title('Closest Fake Image')
Text(0.5, 1.0, 'Closest Fake Image')
So, how do our results look? Eh, OK. Not amazing, right? Ideally we should find the closest good image to our query image, but our reconstruction looks a little rough - we definitely lost some fidelity. The AnoGAN authors use one more trick to address this issue. As you can imagine, this "optimizing backwards across our generator" is far from a perfect science, and there's lots we can do to improve performance. The AnoGAN authors chose to do something really cool here that uses our discriminator to help our optimization land on a more realistic nearest good image. Specifically, they add a term to our loss function that measures the distance between our discriminator's representations of our real and fake images. Let's see how this works.
We're going to compare the representations of our real and fake images in the feature space of our discriminator. The idea here is that through the training process our discriminator has learned lots of useful features (captured in its hidden layers) to measure the perceptual difference between real and fake images. For this reason, techniques like this are sometimes called perceptual loss. The AnoGAN authors again use L1 loss, as shown in equation 4 of their paper:
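Reconstructed to be consistent with the notation of equation 5 later in the post, this discrimination loss compares the discriminator's intermediate features of the real and generated images:

$$ \mathcal{L}_D(z_\gamma) = \sum \left| f(x) - f(G(z_\gamma)) \right| $$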
where $f(\cdot)$ represents the intermediate layers of the discriminator. Let's implement this and add it to our approach. We'll start by pre-computing the intermediate representations of our real image, as these will not change during optimization:
f={} #Precompute feature values for layers
with torch.no_grad():
    for i in range(1, (len(D)-2)): f[i]=D[:i](im_tensor.unsqueeze(0))
Next we'll create a method that computes the discriminator feature loss given a fake image:
def get_d_loss(f_x, fake, D):
    loss_d=nn.L1Loss()(f_x[1], D[:1](fake)) #Get loss value from 1st layer of D
    for i in range(2, (len(D)-2)):
        loss_d+=nn.L1Loss()(f_x[i], D[:i](fake)) #And remaining layers of D
    return loss_d
Finally, we need to balance our two loss terms (our reconstruction loss and discriminator loss). The AnoGAN authors capture this in their equation 5:
$$ \mathcal{L}(z_\gamma) = (1-\lambda) \cdot \mathcal{L}_R(z_\gamma) + \lambda \cdot \mathcal{L}_D(z_\gamma) $$

where $\mathcal{L}(z_\gamma)$ is our overall loss function value, and $\lambda$ controls the balance between our residual and discriminator loss terms. Let's put all of this together into a method we'll call walk_latent_space. Using our new loss function, this method will seek to find the z that minimizes our loss.
def walk_latent_space(G, D, im_tensor, n_iter=1500, lambd=0.1, lr=2e-2, device='cuda'):
    f_x={} #Precompute feature values for all layers of D
    with torch.no_grad():
        for i in range(1, (len(D)-2)): f_x[i]=D[:i](im_tensor.unsqueeze(0))
    z=torch.randn(1, 100, 1, 1, requires_grad=True, device=device) #Random starting point for walk
    opt=optim.Adam([z], lr=lr)
    losses=[]
    for i in tqdm(range(n_iter)):
        fake=G(z) #Get fake image
        loss_r=torch.nn.L1Loss()(fake.squeeze(0), im_tensor) #Residual loss
        loss_d=get_d_loss(f_x, fake, D) #Discriminator loss
        loss=(1-lambd)*loss_r+lambd*loss_d #Total loss
        loss.backward(); opt.step()
        z.grad.zero_(); G.zero_grad(); D.zero_grad()
        losses.append(loss.item())
    return {'z':z, 'loss':loss.item(), 'loss_r':loss_r.item(), 'loss_d':loss_d.item(), 'losses':losses}
Let's try out our new walk_latent_space method on a good image:
im_path=(data_path/dset/'train'/'good').ls()[1]
im=cv2.cvtColor(cv2.imread(str(im_path)), cv2.COLOR_BGR2RGB)
im=cv2.resize(im, (64,64))
im_tensor=((kornia.image_to_tensor(im).float()/128)-1).to(device)
res=walk_latent_space(G, D, im_tensor, n_iter=1500, lambd=0.1, lr=2e-2, device='cuda')
100%|██████████| 1500/1500 [00:21<00:00, 68.25it/s]
Let's see how well our optimizer is performing by plotting our loss function value as we take optimization steps:
plt.plot(res['losses']); plt.grid(1)
Not bad! We can also have a look at the trade-off between our reconstruction loss and discriminator loss. This will be important shortly.
res['loss'], res['loss_d'], res['loss_r']
(0.260686457157135, 1.2064778804779053, 0.15559853613376617)
Finally, let's see how we did! How does our synthetic image compare to our real one?
fig=plt.figure(0, (8,4))
fig.add_subplot(1,2,1)
plt.imshow(im); plt.title('Real Image')
fig.add_subplot(1,2,2)
plt.imshow((kornia.tensor_to_image(G(res['z']))+1)/2); plt.title('Closest Fake Image')
Text(0.5, 1.0, 'Closest Fake Image')
Not bad, right? Adding in our discriminator term really helped keep our image looking realistic as we searched the latent space!
Now we finally have all the pieces we need to find anomalies using our GAN. To see how this works, let's grab an image with some anomalies.
im_path=Path('data/hazelnut/test/print/016.png')
im=cv2.cvtColor(cv2.imread(str(im_path)), cv2.COLOR_BGR2RGB)
im=cv2.resize(im, (64,64))
plt.imshow(im)
<matplotlib.image.AxesImage at 0x7f1ea5610dd0>
As you can see, our image has "print" anomalies, where a mark has been printed on our poor hazelnut! Now, a big part of the AnoGAN idea is that anomalous images like this shouldn't really exist in the latent space of our generator, since we only trained it on good images. So when we try to reconstruct our anomalous image, the reconstruction should fail in a very useful way. Ideally, the reconstruction should yield the "closest good image" - and this closest image is going to help us find anomalies! Let's try it.
im_tensor=((kornia.image_to_tensor(im).float()/128)-1).to(device)
res=walk_latent_space(G, D, im_tensor, n_iter=1500, lambd=0.1, lr=2e-2, device='cuda')
100%|██████████| 1500/1500 [00:22<00:00, 66.72it/s]
plt.plot(res['losses']); plt.grid(1)
Alright, let's have a look at our reconstructed image:
fig=plt.figure(0, (8,4))
fig.add_subplot(1,2,1)
plt.imshow(im); plt.title('Real Image')
fig.add_subplot(1,2,2)
plt.imshow((kornia.tensor_to_image(G(res['z']))+1)/2); plt.title('Closest Fake Image')
Text(0.5, 1.0, 'Closest Fake Image')
And as you can see, our reconstruction isn't great. Happily, this is exactly what we wanted! We shouldn't be able to reconstruct anomalous images. Now, if we can measure how poor our reconstruction is, we can use this metric to measure how anomalous our images are - cool, right? And happily, as we discussed above, we already have a couple of metrics that capture the quality of our reconstruction: our reconstruction loss $\mathcal{L}_R(z_\gamma)$ and our discriminator loss $\mathcal{L}_D(z_\gamma)$. The idea here is that our anomalous images should have higher reconstruction and discriminator losses. We explore this idea rigorously below. But first, let's discuss one more neat trick. Not only can we use GANs to identify anomalous images, we can also use GANs to identify anomalous regions in images. The idea here is that we can compute a "residual image" - the difference between our original and reconstructed image - and areas of the residual image with large values should correspond to anomalous regions, since the original and reconstructed image are maximally different in these regions. Let's try it.
xr=np.abs((im.astype('float')/255)-(kornia.tensor_to_image(G(res['z']))+1)/2) #Residual image, using the z found by walk_latent_space
plt.imshow(np.mean(xr, axis=2))
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f1ea539b310>
See the anomalous regions glowing? Pretty cool, right? We can take this a step further and threshold our residual image to identify anomalous pixels:
thresh=0.5
plt.imshow(np.mean(xr, axis=2)>thresh)
<matplotlib.image.AxesImage at 0x7f1ea533a9d0>
Finally, let's put it all together:
im_mask=np.zeros_like(im)
for i in range(3): im_mask[:,:,i]=np.mean(im, axis=2)
im_mask[np.mean(xr, axis=2)>thresh]=im_mask[np.mean(xr, axis=2)>thresh]//2+np.array([255,0,255])//2
fig=plt.figure(0, (8,8))
fig.add_subplot(2,2,1); plt.imshow(im); plt.title('Image w/ Defect')
fig.add_subplot(2,2,2); plt.imshow((kornia.tensor_to_image(G(res['z']))+1)/2); plt.title('Reconstructed Image')
fig.add_subplot(2,2,3); plt.imshow(np.mean(xr, axis=2)); plt.title('Residual Image')
fig.add_subplot(2,2,4); plt.imshow(im_mask); plt.title('Anomalous Region')
Text(0.5, 1.0, 'Anomalous Region')
Alright, back to that idea of using our reconstruction and discriminator losses to identify anomalous images. Let's loop over all of our test images; for each, we'll create a reconstructed image and compute the reconstruction and discriminator loss. We'll then visualize these results to see if we can use these metrics to find anomalous images.
params={'n_iter':1500, 'lambd':0.1, 'lr':2e-2}
all_res=[]
im_paths=[p for p in (data_path/dset/'test').glob('*/*.png')]
for im_path in im_paths:
    im=cv2.cvtColor(cv2.imread(str(im_path)), cv2.COLOR_BGR2RGB)
    im=cv2.resize(im, (64,64))
    im_tensor=((kornia.image_to_tensor(im).float()/128)-1).to(device)
    res=walk_latent_space(G, D, im_tensor, **params, device=device)
    res['im_fake']=(kornia.tensor_to_image(G(res['z']))+1)/2 #Scaled between 0 and 1
    res['im_path']=im_path
    res['label']=im_path.parent.name
    all_res.append(res)
100%|██████████| 1500/1500 [00:22<00:00, 66.97it/s] (repeated for each test image)
#Cache results for later if needed
# with open(save_dir/'anogan-results-cache-1.p', 'wb') as f:
# pickle.dump(all_res, f, protocol=pickle.HIGHEST_PROTOCOL)
# Load previous results if needed
# with open(save_dir/'anogan-results-cache-1.p', 'rb') as f:
# all_res = pickle.load(f)
labels=np.unique([res['label'] for res in all_res])
cm={l:i for i,l in enumerate(labels)}
my_cmap={0:'r', 1:'b', 2:'g', 3:'c', 4:'m', 5:'y'}
my_markers={0:'o', 1:'x', 2:'^', 3:'<', 4:'>', 5:'*'}
fig=plt.figure(0, (8,8))
handles={}
for res in all_res:
    handles[res['label']]=plt.scatter(res['loss_d'], res['loss_r'], c=my_cmap[cm[res['label']]], marker=my_markers[cm[res['label']]])
plt.legend(handles=list(handles.values()), labels=list(handles.keys()), fontsize=14); plt.grid(1)
plt.xlabel('Discriminator Loss', fontsize=14)
plt.ylabel('Residual Loss', fontsize=14)
Alright, this plot is really interesting! To be honest it didn't come out quite as cleanly as I expected, but this happens fairly frequently in ML, especially when experimenting with newer types of algorithms like GANs. Since our generator should find it more difficult to reproduce anomalous images, we should see a higher reconstruction loss for our anomalous images relative to our good images. As you can see, this isn't really the case here, with some of our good images having medium to high reconstruction losses. However, we do see a pretty reasonable pattern with our discriminator - anomalous images tend to have higher discriminator losses. We can use this to create our very own anomalous image detector. We'll create a simple method that takes in an image, computes our reconstruction, and returns True if our discriminator or reconstruction loss exceeds a set threshold. Note that you could do something more sophisticated here, such as multivariate logistic regression (or even another neural network!) to classify anomalous images in the space of our discriminator and residual losses - this may lead to overfitting and we would likely want to introduce a proper validation set, but it could be interesting to experiment with.
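To give a feel for the logistic regression idea, here's a minimal sketch. The all_res list below is synthetic stand-in data (in the notebook it comes from running walk_latent_space over the test images), and the two-feature logistic regression is fit with plain-numpy gradient descent so there are no extra dependencies - a toy illustration, not a tuned classifier.

```python
import numpy as np

# Synthetic stand-in for all_res: dicts with 'loss_d', 'loss_r', 'label' keys.
rng = np.random.default_rng(0)
all_res = ([{'loss_d': rng.normal(1.0, 0.1), 'loss_r': rng.normal(0.15, 0.03), 'label': 'good'} for _ in range(40)]
         + [{'loss_d': rng.normal(1.5, 0.15), 'loss_r': rng.normal(0.22, 0.04), 'label': 'crack'} for _ in range(40)])

X = np.array([[r['loss_d'], r['loss_r']] for r in all_res])          # (n, 2) feature matrix
y = np.array([0 if r['label'] == 'good' else 1 for r in all_res])    # 1 = anomalous

# Logistic regression via gradient descent (bias term appended as a column of ones)
Xb = np.hstack([X, np.ones((len(X), 1))])
w = np.zeros(3)
for _ in range(5000):
    p = 1 / (1 + np.exp(-Xb @ w))            # predicted probabilities
    w -= 0.1 * Xb.T @ (p - y) / len(y)       # mean log-loss gradient step

preds = (1 / (1 + np.exp(-Xb @ w)) > 0.5).astype(int)
acc = (preds == y).mean()
print('Training accuracy: %.2f' % acc)
```

Because this is fit and evaluated on the same points, the accuracy is optimistic - exactly the overfitting concern mentioned above.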
def anomaly_detector(im, G, D, d_loss_thresh, r_loss_thresh):
    '''Decides if image im is anomalous'''
    im_tensor=((kornia.image_to_tensor(im).float()/128)-1).to(device)
    res=walk_latent_space(G, D, im_tensor, n_iter=1500, lambd=0.1, lr=2e-2, device='cuda')
    return res['loss_d']>d_loss_thresh or res['loss_r']>r_loss_thresh
Ok, let's test out our new function! We'll choose a d_loss_thresh
of 1.3 (just based on eyeballing the plot), and an r_loss_thresh
value of 0.21. Note that our residual loss threshold isn't really doing much here, but it may be useful for other types of images or other GAN weights. Let's measure the accuracy of our anomaly detector:
im_path=data_path/dset/'test'/'good'/'001.png' #Test on good example - should return False
im=cv2.cvtColor(cv2.imread(str(im_path)), cv2.COLOR_BGR2RGB)
im=cv2.resize(im, (64,64))
predicted_anomaly=anomaly_detector(im, G, D, d_loss_thresh=1.3, r_loss_thresh=0.21)
print(predicted_anomaly)
100%|██████████| 1500/1500 [00:21<00:00, 68.70it/s]
False
im_path=data_path/dset/'test'/'crack'/'001.png' #Test on defect - should return True
im=cv2.cvtColor(cv2.imread(str(im_path)), cv2.COLOR_BGR2RGB)
im=cv2.resize(im, (64,64))
predicted_anomaly=anomaly_detector(im, G, D, d_loss_thresh=1.3, r_loss_thresh=0.21)
print(predicted_anomaly)
100%|██████████| 1500/1500 [00:22<00:00, 66.02it/s]
True
Alright, 2 for 2! Now, let's see what our accuracy is across all our test data.
num_correct=0; total=0 #Measure accuracy on good test examples
for im_path in (data_path/dset/'test'/'good').glob('*.png'):
    im=cv2.cvtColor(cv2.imread(str(im_path)), cv2.COLOR_BGR2RGB)
    im=cv2.resize(im, (64,64))
    predicted_anomaly=anomaly_detector(im, G, D, d_loss_thresh=1.3, r_loss_thresh=0.21)
    if not predicted_anomaly: num_correct+=1
    total+=1
    print(im_path, predicted_anomaly)
data/hazelnut/test/good/018.png True
data/hazelnut/test/good/034.png False
data/hazelnut/test/good/032.png False
data/hazelnut/test/good/028.png False
data/hazelnut/test/good/012.png False
data/hazelnut/test/good/033.png False
data/hazelnut/test/good/022.png False
data/hazelnut/test/good/000.png False
data/hazelnut/test/good/008.png False
data/hazelnut/test/good/021.png False
data/hazelnut/test/good/006.png False
data/hazelnut/test/good/026.png False
data/hazelnut/test/good/001.png False
data/hazelnut/test/good/014.png False
data/hazelnut/test/good/038.png False
data/hazelnut/test/good/007.png False
data/hazelnut/test/good/039.png False
data/hazelnut/test/good/002.png True
data/hazelnut/test/good/005.png False
data/hazelnut/test/good/003.png False
data/hazelnut/test/good/016.png False
data/hazelnut/test/good/019.png False
data/hazelnut/test/good/031.png False
data/hazelnut/test/good/035.png False
data/hazelnut/test/good/036.png False
data/hazelnut/test/good/017.png False
data/hazelnut/test/good/013.png False
data/hazelnut/test/good/010.png False
data/hazelnut/test/good/004.png False
data/hazelnut/test/good/027.png False
data/hazelnut/test/good/029.png False
data/hazelnut/test/good/025.png False
data/hazelnut/test/good/015.png False
data/hazelnut/test/good/023.png False
data/hazelnut/test/good/037.png False
data/hazelnut/test/good/024.png False
data/hazelnut/test/good/030.png False
data/hazelnut/test/good/009.png False
data/hazelnut/test/good/011.png False
data/hazelnut/test/good/020.png False
num_correct/total
0.95
num_correct=0; total=0 #Measure accuracy on defective test examples
for im_path in (data_path/dset/'test').glob('*/*'):
    if 'good' in str(im_path): continue
    im=cv2.cvtColor(cv2.imread(str(im_path)), cv2.COLOR_BGR2RGB)
    im=cv2.resize(im, (64,64))
    predicted_anomaly=anomaly_detector(im, G, D, d_loss_thresh=1.3, r_loss_thresh=0.21)
    if predicted_anomaly: num_correct+=1
    total+=1
    print(im_path, predicted_anomaly)
data/hazelnut/test/print/012.png True
data/hazelnut/test/print/000.png True
data/hazelnut/test/print/008.png True
data/hazelnut/test/print/006.png True
data/hazelnut/test/print/001.png True
data/hazelnut/test/print/014.png False
data/hazelnut/test/print/007.png True
data/hazelnut/test/print/002.png True
data/hazelnut/test/print/005.png False
data/hazelnut/test/print/003.png False
data/hazelnut/test/print/016.png True
data/hazelnut/test/print/013.png True
data/hazelnut/test/print/010.png True
data/hazelnut/test/print/004.png True
data/hazelnut/test/print/015.png True
data/hazelnut/test/print/009.png False
data/hazelnut/test/print/011.png True
data/hazelnut/test/crack/012.png False
data/hazelnut/test/crack/000.png True
data/hazelnut/test/crack/008.png True
data/hazelnut/test/crack/006.png False
data/hazelnut/test/crack/001.png True
data/hazelnut/test/crack/014.png True
data/hazelnut/test/crack/007.png True
data/hazelnut/test/crack/002.png False
data/hazelnut/test/crack/005.png False
data/hazelnut/test/crack/003.png False
data/hazelnut/test/crack/016.png False
data/hazelnut/test/crack/017.png True
data/hazelnut/test/crack/013.png False
data/hazelnut/test/crack/010.png True
data/hazelnut/test/crack/004.png False
data/hazelnut/test/crack/015.png True
data/hazelnut/test/crack/009.png True
data/hazelnut/test/crack/011.png True
data/hazelnut/test/hole/012.png True
data/hazelnut/test/hole/000.png True
data/hazelnut/test/hole/008.png True
data/hazelnut/test/hole/006.png True
data/hazelnut/test/hole/001.png False
data/hazelnut/test/hole/014.png True
data/hazelnut/test/hole/007.png False
data/hazelnut/test/hole/002.png False
data/hazelnut/test/hole/005.png True
data/hazelnut/test/hole/003.png False
data/hazelnut/test/hole/016.png False
data/hazelnut/test/hole/017.png True
data/hazelnut/test/hole/013.png False
data/hazelnut/test/hole/010.png False
data/hazelnut/test/hole/004.png True
data/hazelnut/test/hole/015.png False
data/hazelnut/test/hole/009.png False
data/hazelnut/test/hole/011.png True
data/hazelnut/test/cut/012.png False
data/hazelnut/test/cut/000.png True
data/hazelnut/test/cut/008.png True
data/hazelnut/test/cut/006.png True
data/hazelnut/test/cut/001.png True
data/hazelnut/test/cut/014.png False
data/hazelnut/test/cut/007.png True
data/hazelnut/test/cut/002.png True
data/hazelnut/test/cut/005.png True
data/hazelnut/test/cut/003.png False
data/hazelnut/test/cut/016.png False
data/hazelnut/test/cut/013.png False
data/hazelnut/test/cut/010.png False
data/hazelnut/test/cut/004.png False
data/hazelnut/test/cut/015.png True
data/hazelnut/test/cut/009.png True
data/hazelnut/test/cut/011.png True
num_correct/total
0.6
For comparison, the MVTec AD paper cites an accuracy on hazelnuts using AnoGAN of 0.83 for the good examples, and 0.16 for the anomalous examples. We appear to be performing a bit better, with an accuracy on good examples of 95% and an accuracy on anomalous examples of 60%. There are some differences between decisioning criteria, and our ad-hoc threshold-based decisioning may not generalize terribly well - this would require more experimentation - but not a bad start!
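If you'd like to pick thresholds a bit more systematically than eyeballing the plot, one option is to sweep a candidate threshold over the saved discriminator losses and keep the value that maximizes balanced accuracy. Here's a minimal sketch using synthetic stand-in scores - in the notebook, these would be the per-image loss_d values collected in all_res:

```python
import numpy as np

# Synthetic stand-ins for the per-image discriminator losses (loss_d).
rng = np.random.default_rng(1)
good_scores = rng.normal(1.0, 0.15, 40)   # losses for good images
bad_scores = rng.normal(1.5, 0.2, 70)     # losses for defective images

best_thresh, best_bal_acc = None, 0.0
for thresh in np.linspace(0.5, 2.0, 151):
    tnr = (good_scores <= thresh).mean()  # good images correctly kept
    tpr = (bad_scores > thresh).mean()    # defects correctly flagged
    bal_acc = (tnr + tpr) / 2             # balanced accuracy handles class imbalance
    if bal_acc > best_bal_acc:
        best_thresh, best_bal_acc = thresh, bal_acc

print('Best threshold: %.2f, balanced accuracy: %.2f' % (best_thresh, best_bal_acc))
```

Since this tunes the threshold on the same data we evaluate on, a proper test split would be needed before trusting the resulting number.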
Finally, if you would like to wrap up (and maybe deploy!) our anomaly detector, check out the anomaly_detector.py
script; you can run it from the terminal like this:
python anomaly_detector.py --path_to_image PATH_TO_IMAGE --weights_dir PATH_TO_WEIGHTS --d_loss_thresh 1.3 --r_loss_thresh 0.21