#!/usr/bin/env python
# coding: utf-8

# # Systematic
#
# Compares visual prompt-engineering variants (cropping, blurring / darkening
# the background, ...) by how much they change CLIP's softmax probability for
# the target class prompt, relative to feeding the unmodified image.

# In[ ]:


get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')

import clip
from evaluation_utils import norm, denorm
from general_utils import *
from datasets.lvis_oneshot3 import LVIS_OneShot3

clip_device = 'cuda'
clip_model, preprocess = clip.load("ViT-B/16", device=clip_device)
clip_model.eval();

from models.clipseg import CLIPDensePredTMasked

clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)
clip_mask_model.eval();


# In[ ]:


lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False,
                     text_class_labels=True, image_size=352, min_area=0.1,
                     min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)


# In[ ]:


plot_data(lvis)


# In[ ]:


from collections import defaultdict
import json

# Map every LVIS image id (train + val) to the names of the categories
# annotated on it; used below to build distractor prompts per sample.
objects_per_image = defaultdict(set)
for split in ('train', 'val'):
    # Context manager closes the (large) annotation file promptly, and each
    # split is released before the next one is parsed.
    with open(expanduser(f'~/datasets/LVIS/lvis_v1_{split}.json')) as f:
        lvis_raw = json.load(f)
    for ann in lvis_raw['annotations']:
        objects_per_image[ann['image_id']].add(ann['category_id'])
    del lvis_raw

objects_per_image = {img_id: [lvis.category_names[cat_id] for cat_id in cat_ids]
                     for img_id, cat_ids in objects_per_image.items()}


# In[ ]:


#bs = 32
#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]


# In[ ]:


from general_utils import get_batch
from functools import partial
from evaluation_utils import img_preprocess
import torch


def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):
    """Score each sample's target class prompt against distractor prompts with CLIP.

    For every sample, the image features computed from ``process(batch)`` are
    compared with the prompt ``'a photo of a <label>'`` (always index 0) plus
    one prompt per other category annotated on the sample's support image.

    Args:
        batches_or_dataset: either a list of pre-loaded ``(batch, batch_y)``
            pairs (all are used) or a dataset, which is wrapped in an
            unshuffled ``DataLoader`` (batch size 32) and read for at most
            50 batches.
        process: callable mapping a raw ``batch`` either to a dict of keyword
            arguments for ``clip_mask_model.visual_forward`` or to an image
            tensor (resized to 224x224 before encoding).
        mask, clipmask: unused; kept so existing call sites keep working.

    Returns:
        ``(valid_sims, all_prompts)``: one softmax vector per sample (entry 0
        belongs to the target label) and the matching list of prompts.

    Notes:
        Sample metadata is always read from the global ``lvis`` dataset, so
        batches must come from it, in order. The last ``batch``, ``logits``
        and ``sim`` are also exposed as module globals for interactive
        inspection.
    """
    global batch, logits, sim

    all_prompts = []
    valid_sims = []
    with torch.no_grad():
        torch.manual_seed(571)

        if isinstance(batches_or_dataset, list):
            loader = batches_or_dataset  # already loaded
            max_iter = float('inf')
        else:
            loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)
            max_iter = 50

        # Index of the first sample of the current batch. A running offset is
        # used because `bs * i_batch` is wrong when batch sizes differ
        # (e.g. a short final batch).
        sample_offset = 0
        for i_batch, (batch, batch_y) in enumerate(loader):
            if i_batch >= max_iter:
                break

            processed_batch = process(batch)
            if isinstance(processed_batch, dict):
                image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()
            else:
                # Reuse the already processed tensor instead of calling `process` again.
                processed_batch = nnf.interpolate(processed_batch.to(clip_device), (224, 224), mode='bilinear')
                image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            bs = len(batch[0])
            for j in range(bs):
                c, _, sid, _ = lvis.sample_ids[sample_offset + j]
                support_image = basename(lvis.samples[c][sid])
                img_objs = [o.replace('_', ' ') for o in objects_per_image[int(support_image)]]

                target = batch_y[2][j]
                # Compare underscore-normalised names so the target label is
                # never duplicated as one of its own distractors.
                target_norm = str(target).replace('_', ' ')
                other_words = [f'a photo of a {o}' for o in img_objs if o != target_norm]
                prompts = [f'a photo of a {target}'] + other_words
                all_prompts.append(prompts)

                text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))
                text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)

                logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T
                sim = torch.softmax(logits, dim=-1)
                valid_sims.append(sim)

            sample_offset += bs

    return valid_sims, all_prompts


def new_img_preprocess(x):
    """Build ``visual_forward`` kwargs: image ``x[1]`` with mask spec ``(11, 'cls_token', x[2])``.

    Presumably layer 11, applied to the CLS token -- confirm against CLIPDensePredTMasked.
    """
    return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}


#get_similarities(lvis, partial(img_preprocess, center_context=0.5));
get_similarities(lvis, lambda x: x[1]);


# In[ ]:


preprocessing_functions = [
#     ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],
#     ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],
#     ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],
#     ['colorize object red', partial(img_preprocess, colorize=True)],
#     ['add red outline', partial(img_preprocess, outline=True)],
#     ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],
#     ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],
#     ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],
#     ['BG blur', partial(img_preprocess, blur=3)],
#     ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],
#     ['crop large context', partial(img_preprocess, center_context=0.5)],
#     ['crop small context', partial(img_preprocess, center_context=0.1)],
    ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],
    # Previously `blur=3, bg_fac=0.1`, i.e. no crop at all (identical to
    # 'BG blur & intensity 10%'); now actually crops as the label says.
    ['crop & intensity 10%', partial(img_preprocess, center_context=0.5, bg_fac=0.1)],
#     ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],
]


def mean_target_gain(sims, base_sims, base_prompts, min_prompts=3):
    """Average change of the target-prompt probability relative to the baseline.

    Only samples with at least ``min_prompts`` prompts (the target plus at
    least ``min_prompts - 1`` distractors) are counted.
    """
    return np.mean([s[0].cpu() - b[0].cpu()
                    for s, b, p in zip(sims, base_sims, base_prompts)
                    if len(p) >= min_prompts])


base, base_p = get_similarities(lvis, lambda x: x[1])
outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]


# In[ ]:


outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]


# In[ ]:


for sims, _ in outs2:
    print(mean_target_gain(sims, base, base_p))


# In[ ]:


from pandas import DataFrame

# LaTeX table rows: "<variant> & <gain in percentage points> \\"
tab = {name: mean_target_gain(sims, base, base_p)
       for (name, _), (sims, _) in zip(preprocessing_functions, outs)}
print('\n'.join(f'{k} & {v*100:.2f} \\\\' for k, v in tab.items()))


# # Visual

# In[ ]:


from evaluation_utils import denorm, norm


# In[ ]:


def load_sample(filename, filename2):
    """Load an image and its mask from ``~/cloud/resources/sample_images``.

    Returns ``[None, image, mask]``, the same index layout the batches passed
    to ``img_preprocess`` use (image at 1, mask at 2). ``image`` is
    normalised, resized and center-cropped to 224 with a leading batch
    dimension; ``mask`` keeps only the first channel of ``filename2``.
    """
    from os.path import join
    bp = expanduser('~/cloud/resources/sample_images')

    # NOTE(review): ImageNet statistics rather than CLIP's own normalisation
    # constants -- presumably chosen to match `denorm`; confirm.
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Resize(224),
        transforms.CenterCrop(224)
    ])
    tf2 = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(224),
        transforms.CenterCrop(224)
    ])

    # Context managers close the files; convert('RGB') guards against
    # grayscale/RGBA images, which would break the 3-channel Normalize.
    with Image.open(join(bp, filename)) as img:
        image = tf(img.convert('RGB')).unsqueeze(0)
    with Image.open(join(bp, filename2)) as mask_img:
        mask = tf2(mask_img)[:1]
    return [None, image, mask]


def all_preprocessing(inp1):
    """Return the six visual-prompt variants shown side by side in the figure."""
    return [
        img_preprocess(inp1),
        img_preprocess(inp1, colorize=True),
        img_preprocess(inp1, outline=True),
        img_preprocess(inp1, blur=3),
        img_preprocess(inp1, bg_fac=0.1),
        #img_preprocess(inp1, bg_fac=0.5),
        #img_preprocess(inp1, blur=3, bg_fac=0.5),
        img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),
    ]


# In[ ]:


from torchvision import transforms
from PIL import Image
from matplotlib import pyplot as plt
from evaluation_utils import img_preprocess
import clip
import os

# Each entry: (sample loaded as [None, image, mask], object names to query).
# The first object names the masked object and is highlighted in the plot.
images_queries = [
    [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],
    [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],
]

# Load the CPU copy of CLIP once, under its own name, so the CUDA `clip_model`
# used by `get_similarities` is not overwritten.
clip_model_cpu, _ = clip.load("ViT-B/16", device='cpu')

n_cols = 6  # one column per variant returned by `all_preprocessing`
_, ax = plt.subplots(2 * len(images_queries), n_cols, figsize=(14, 4.5 * len(images_queries)))

for j, (images, objects) in enumerate(images_queries):
    joint_image = torch.stack(all_preprocessing(images))[:, 0]
    prompts = [f'a photo of a {obj}' for obj in objects]

    with torch.no_grad():
        image_features = clip_model_cpu.encode_image(joint_image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        text_cond = clip_model_cpu.encode_text(clip.tokenize(prompts))
        text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)

        logits = clip_model_cpu.logit_scale.exp() * image_features @ text_cond.T
        sim = torch.softmax(logits, dim=-1).cpu()

    img_row, bar_row = ax[2 * j], ax[2 * j + 1]
    for i, img in enumerate(joint_image):
        img_row[i].axis('off')
        img_row[i].imshow(torch.clamp(denorm(img).permute(1, 2, 0), 0, 1))

        bar_ax = bar_row[i]
        bar_ax.grid(True)
        bar_ax.set_ylim(0, 1)
        bar_ax.set_yticklabels([])
        bar_ax.set_xticks([])
        for k, (prob, obj) in enumerate(zip(sim[i], objects)):
            # The first query names the masked object: draw it in a distinct colour.
            bar_ax.bar(k, prob, color=plt.cm.tab20(3) if k == 0 else plt.cm.tab20(1))
            bar_ax.text(k, 0.07, obj, rotation=90, ha='center', fontsize=15)

plt.tight_layout()
os.makedirs('figures', exist_ok=True)  # savefig does not create missing directories
plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')