Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Image similarity is the computer vision task of measuring how similar two images are, e.g. to build image retrieval systems. A simple but quite powerful approach is to use a CNN trained for image classification to represent an image, and to define the similarity between two images as the L2 distance between their respective DNN representations.
This is the approach taken in this notebook. More accurate similarity measures are based on DNNs trained explicitly for image similarity, such as the FaceNet work, which uses a Siamese network architecture.
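To make the core idea concrete, here is a minimal numpy sketch of the L2 distance between two feature vectors. The notebook itself uses the compute_distances helper from utils_cv; this snippet is purely illustrative.
# Illustrative sketch only: L2 (Euclidean) distance between two DNN feature vectors
import numpy as np
def l2_distance(feat1, feat2):
    return float(np.linalg.norm(np.asarray(feat1) - np.asarray(feat2)))
# Two hypothetical 512-dimensional embeddings; a smaller distance means more similar images
a, b = np.random.rand(512), np.random.rand(512)
print(f"L2 distance = {l2_distance(a, b):.3f}")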
This notebook assumes you have a webcam connected to your machine. We use the ipywebrtc
module to show the webcam widget in the notebook. Currently, these widgets render correctly in Chrome only. For more details about the widget, please visit the ipywebrtc GitHub repository or its documentation.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
# Regular python libraries
import sys
import io
import os
from pathlib import Path
import time
# fast.ai
import fastai
from fastai.vision import (
cnn_learner, DatasetType, ImageList, imagenet_stats, models, open_image
)
from ipywebrtc import CameraStream, ImageRecorder
from ipywidgets import HBox, VBox, Label, Layout, widgets, Widget
import scrapbook as sb
# Computer Vision repository
sys.path.extend(["..", "../.."]) # to access the utils_cv library
from utils_cv.classification.data import Urls
from utils_cv.classification.model import model_to_learner
from utils_cv.common.data import unzip_url
from utils_cv.common.gpu import which_processor, db_num_workers
from utils_cv.similarity.metrics import compute_distances
from utils_cv.similarity.model import compute_feature, compute_features_learner
from utils_cv.similarity.plot import plot_distances
print(f"Fast.ai: {fastai.__version__}")
which_processor()
Fast.ai: 1.0.48
Fast.ai (Torch) is using GPU: Tesla V100-PCIE-16GB
We show a simple example of how to build an image retrieval system. Given a query image, the goal is to find all similar images in a reference set. This is implemented using the following steps:
1. Load a pre-trained CNN and specify which layer to use as the image representation.
2. Compute the DNN features for all reference images.
3. Compute the DNN feature for the query image.
4. Rank the reference images by their L2 distance to the query feature.
Here, we use a pre-trained ResNet model from the fastai.vision
package; alternatively, users can bring their own trained models. For more details about image classification, including transfer learning, please see the notebooks in the classification folder.
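If you do bring your own model, a sketch of the alternative is shown below in comment form, assuming you previously exported the model with learn.export() (which writes an export.pkl file); the directory path is hypothetical.
# Sketch: load a custom exported model instead of the ImageNet-pretrained one
# from fastai.vision import load_learner
# learn = load_learner("path/to/exported_model_dir")  # hypothetical directory containing export.pkl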
First, we specify the query image and the reference images.
IM_SIZE = 300 # image size in pixels. Reduce to speed-up demo.
# Set path to query and reference images
im_path = unzip_url(Urls.fridge_objects_path, exist_ok=True)
ref_im_path = os.path.join(im_path, "can")
query_im_path = os.path.join(im_path, "can", "1.jpg")
print(f"Query image path = {query_im_path}")
print(f"Reference images directory = {ref_im_path}")
# Construct a DataBunch
ref_data = (
ImageList.from_folder(ref_im_path)
.split_none()
.label_from_folder()
.transform(tfms=None, size=IM_SIZE)
.databunch(bs=2, num_workers=db_num_workers())
.normalize(imagenet_stats)
)
Query image path = /data/home/pabuehle/Desktop/ComputerVision/data/fridgeObjects/can/1.jpg
Reference images directory = /data/home/pabuehle/Desktop/ComputerVision/data/fridgeObjects/can
Next, we load the model.
# Load the pre-trained model; alternatively, one can load a custom model using fastai.vision.load_learner().
learn = cnn_learner(ref_data, models.resnet18, pretrained=True)
In this section, we compute an embedding for each reference image, using the output of the model's penultimate layer to represent the image.
# Specify which DNN layer to use as embedding (here: penultimate layer)
embedding_layer = learn.model[1][-2]
ref_features = compute_features_learner(
ref_data, DatasetType.Train, learn, embedding_layer
)
print(
f"Computed DNN features for {len(list(ref_features))} images, \
each consisting of {len(ref_features[list(ref_features)[0]])} floating point values."
)
Computed DNN features for 32 images, each consisting of 512 floating point values.
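To get a feel for these embeddings, the short sketch below prints the first few values of one reference feature; as in the print statement above, ref_features maps each image path to its feature vector.
# Inspect one entry of ref_features (image path -> 512-dim feature vector)
example_path = list(ref_features)[0]
print(example_path)
print(ref_features[example_path][:5])  # first 5 of the 512 floating point values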
First, we load a query image and find the most similar images within the reference set using the model.
query_im = open_image(query_im_path, convert_mode='RGB')
query_im
# Compute DNN feature for the query image
start_time = time.time()
query_feature = compute_feature(query_im, learn, embedding_layer)
print(f"Computing the DNN representation for the query image took {time.time()-start_time:.4f} sec")
Computing the DNN representation for the query image took 0.1155 sec
# Compute the distances between the query and all reference images
distances = compute_distances(query_feature, ref_features)
plot_distances(distances, num_rows=1, num_cols=7, figsize=(15,5))
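Besides plotting, we can rank the reference images explicitly. Below is a small sketch which assumes that compute_distances returns (image path, distance) pairs, consistent with how they are unpacked in the webcam callback later in this notebook.
# Sort the reference images by their distance to the query and list the closest matches
top_matches = sorted(distances, key=lambda d: d[1])[:5]
for path, dist in top_matches:
    print(f"{dist:.3f}  {path}")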
Now, we use a webcam stream for image retrieval. We use ipywebrtc
to start the webcam and send its video stream to the notebook's widget. Note that Jupyter widgets are quite unstable; if the widget below does not show, see the "Troubleshooting" section in this FAQ for possible fixes.
# Webcam
w_cam = CameraStream(
constraints={
"facing_mode": "user",
"audio": False,
"video": {"width": IM_SIZE, "height": IM_SIZE},
},
layout=Layout(width=f"{IM_SIZE}px"),
)
# Image recorder for taking a snapshot
w_imrecorder = ImageRecorder(stream=w_cam, layout=Layout(padding="0 0 0 50px"))
# Label widget to show our retrieval results
w_label = Label(layout=Layout(padding="0 0 0 50px"))
w_im = widgets.Image()
def classify_frame(_):
    """ Find the reference image most similar to a webcam snapshot,
    using the pre-trained model's embedding
    """
    # Once capturing started, remove the capture widget since we don't need it anymore
    if w_imrecorder.layout.display != "none":
        w_imrecorder.layout.display = "none"
    try:
        cam_im = open_image(io.BytesIO(w_imrecorder.image.value), convert_mode="RGB")
        # Compute the DNN representation for the webcam frame
        cam_feature = compute_feature(cam_im, learn, embedding_layer)
        # Retrieve the most similar image among the reference images
        # (compute_distances returns (image path, distance) pairs)
        cam_distances = compute_distances(cam_feature, ref_features)
        im_path, distance = min(cam_distances, key=lambda d: d[1])
        # Show result label and the retrieved image
        w_label.value = (
            f"Most similar image with L2 distance of {distance:0.3f}: {im_path}"
        )
        w_im.value = open(im_path, "rb").read()
    except OSError:
        # If im_recorder doesn't have valid image data, skip it.
        pass
    # Take the next snapshot programmatically
    w_imrecorder.recording = True
# Register classify_frame as a callback. Will be called whenever image.value changes.
w_imrecorder.image.observe(classify_frame, "value")
# Show widgets
VBox([w_label, HBox([w_cam, w_imrecorder, w_im])])
Now, click the capture button in the widget to start image retrieval. The most similar reference image is shown, along with its distance to the query image captured from the webcam. Note that some browsers either do not render the widget correctly or do not show it at all; in our experience, Chrome is the most reliable.
In this notebook, we used a simple example to demonstrate how to use a pre-trained model to build an image retrieval system. The model is trained on the object labels in the ImageNet training set and hence will only work well for such objects. Our 01_training_and_evaluation_introduction notebook shows how to fine-tune and customize the model so that it also performs well on other objects you may be interested in.
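As a rough sketch of that customization (not the full recipe), fine-tuning reuses the same fastai calls as above; here we point cnn_learner at this notebook's small reference DataBunch purely to illustrate the API.
# Sketch: fine-tune the architecture on a labeled DataBunch
# (see 01_training_and_evaluation_introduction for the proper recipe, incl. a validation split)
ft_learn = cnn_learner(ref_data, models.resnet18, pretrained=True)
ft_learn.fit_one_cycle(1)  # a single training cycle, just to illustrate the API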
# Stop the model and webcam
Widget.close_all()
# Log some outputs which can be used during testing to verify correct notebook execution
sb.glue("query_feature", list(query_feature))
sb.glue("distances", distances)