! pip install -r https://raw.githubusercontent.com/ultralytics/yolov5/master/requirements.txt > /dev/null
! pip install git+https://github.com/openai/CLIP.git > /dev/null

%matplotlib inline
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
import numpy as np
import torch
import clip
import cv2
import os


def objectDetection(img_path: str, model) -> tuple:
    """Run YOLOv5 on an image and return the cropped detections plus the annotated image."""
    image = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    image_name = os.path.basename(img_path)
    image_name = image_name.split('.')[0]

    result = model(image)
    result.crop(save_dir=image_name)        # save one crop per detected object
    detectedObjects = result.render()[0]    # full image with bounding boxes drawn

    path = image_name + '/crops/**/*.jpg'
    listOfObjects = []
    for filename in glob(path):
        obj = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2RGB)
        listOfObjects.append(obj)
    return listOfObjects, detectedObjects


def similarity_top(similarity_list: list, listOfObjects: list, N) -> tuple:
    """Return the scores and crops of the N highest-scoring objects."""
    results = zip(range(len(similarity_list)), similarity_list)
    results = sorted(results, key=lambda x: x[1], reverse=True)
    images = []
    scores = []
    for index, score in results[:N]:
        scores.append(score)
        images.append(listOfObjects[index])
    return scores, images


def findObjects(listOfObjects: list, query: str, model, preprocess, device: str, N) -> tuple:
    """Encode the crops and the text query with CLIP, then rank the crops by similarity."""
    objects = torch.stack([preprocess(Image.fromarray(im)) for im in listOfObjects]).to(device)
    with torch.no_grad():
        image_features = model.encode_image(objects)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        text_features = model.encode_text(clip.tokenize(query).to(device))
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Compare the description vector with the photo vectors
    # @: https://docs.python.org/3/whatsnew/3.5.html?highlight=operator#whatsnew-pep-465
    similarity = (text_features.cpu().numpy() @ image_features.cpu().numpy().T) * 100
    similarity = similarity[0]

    scores, images = similarity_top(similarity, listOfObjects, N=N)
    return scores, images


def plotResults(scores: list, images: list, n):
    """Display the top-n crops with their similarity scores as titles."""
    plt.figure(figsize=(20, 5))
    for index, img in enumerate(images):
        plt.subplot(1, n, index + 1)
        plt.imshow(img)
        plt.title(scores[index])
        plt.axis('off')
    plt.show();


OBJDETECTIONMODEL = 'yolov5x6'
OBJDETECTIONREPO = 'ultralytics/yolov5'
FINDERMODEL = 'ViT-B/32'
DEVICE = 'cpu'
N = 5

objectDetectorModel = torch.hub.load(OBJDETECTIONREPO, OBJDETECTIONMODEL);
objectFinderModel, preProcess = clip.load(FINDERMODEL, device=DEVICE);

# Available CLIP models:
# ['RN50',
#  'RN101',
#  'RN50x4',
#  'RN50x16',
#  'RN50x64',
#  'ViT-B/32',
#  'ViT-B/16',
#  'ViT-L/14',
#  'ViT-L/14@336px']


def pipeline(image, query):
    """Detect objects, show the annotated image, then plot the N crops that best match the query."""
    listOfObjects, detectedObjects = objectDetection(image, objectDetectorModel)
    scores, images = findObjects(listOfObjects, query, objectFinderModel, preProcess, DEVICE, N)

    plt.figure(figsize=(20, 5))
    plt.axis('off')
    plt.imshow(detectedObjects)

    plotResults(scores, images, N)


pipeline('/content/1.jpg', 'clock')
pipeline('/content/2.jpg', 'black backpack')
pipeline('/content/3.jpg', 'man with a red shirt')
pipeline('/content/3.jpg', 'woman with blue pants')
pipeline('/content/5.jpg', 'wine glass')
pipeline('/content/6.jpg', 'a big gray airplane')
pipeline('/content/4.jpg', 'black bag')
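
# --- Optional: automatic device selection (a sketch, not part of the original run) ---
# Hard-coding DEVICE = 'cpu' works everywhere, but when a GPU is available both YOLOv5 and
# CLIP benefit from it. A minimal variant of the setup above, assuming the torch.hub YOLOv5
# wrapper moves to a device like any nn.Module:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
objectFinderModel, preProcess = clip.load(FINDERMODEL, device=DEVICE)   # reload CLIP on the chosen device
objectDetectorModel = objectDetectorModel.to(DEVICE)                    # move YOLOv5 to the same device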
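
# --- Optional: ranking with torch.topk (a sketch, not from the original notebook) ---
# similarity_top ranks crops with a plain Python sort over (index, score) pairs; the same
# ranking can be expressed with torch.topk. The helper name below is hypothetical and is a
# drop-in alternative, not the author's implementation.
def similarity_top_torch(similarity_list, listOfObjects, N):
    scores_t = torch.as_tensor(similarity_list)                   # accepts lists and numpy arrays
    top = torch.topk(scores_t, k=min(N, len(listOfObjects)))      # highest N scores and their indices
    scores = top.values.tolist()
    images = [listOfObjects[i] for i in top.indices.tolist()]
    return scores, images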