%load_ext autoreload
%autoreload 2
import clip
from evaluation_utils import norm, denorm
from general_utils import *
from datasets.lvis_oneshot3 import LVIS_OneShot3
# Device on which both CLIP models below are placed.
clip_device = 'cuda'
# Plain CLIP ViT-B/16; `preprocess` is the second value returned by clip.load.
clip_model, preprocess = clip.load("ViT-B/16", device=clip_device)
# Trailing ';' is a no-op in Python; presumably kept to suppress the notebook echo.
clip_model.eval();
from models.clipseg import CLIPDensePredTMasked
# Second model built from the same 'ViT-B/16' version string; its visual_forward
# is what get_similarities calls, optionally with a `mask` keyword argument.
clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)
clip_mask_model.eval();
# LVIS one-shot dataset ('train_fixed' split) at image_size=352; it is read
# further down via lvis.sample_ids, lvis.samples and lvis.category_names.
lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False,
text_class_labels=True, image_size=352, min_area=0.1,
min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)
# plot_data is not imported by name here; presumably it comes from `general_utils import *`.
plot_data(lvis)
from collections import defaultdict
import json
# Map every annotated LVIS image id to the names of all categories annotated in
# it. The train and val annotation files are read one after the other inside a
# `with` block, so each file handle is closed right away and only one of the
# parsed JSON documents is alive at any time.
category_ids_per_image = defaultdict(set)
for annotation_file in ('~/datasets/LVIS/lvis_v1_train.json',
                        '~/datasets/LVIS/lvis_v1_val.json'):
    with open(expanduser(annotation_file)) as fh:
        annotations = json.load(fh)['annotations']
    for ann in annotations:
        category_ids_per_image[ann['image_id']].add(ann['category_id'])
    del annotations

# Translate category ids to names through the dataset's own lookup table.
objects_per_image = {
    image_id: [lvis.category_names[category_id] for category_id in category_ids]
    for image_id, category_ids in category_ids_per_image.items()
}
del category_ids_per_image
#bs = 32
#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]
from general_utils import get_batch
from functools import partial
from evaluation_utils import img_preprocess
import torch
def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):
    """Score each sample's class prompt against prompts for the other objects in its image.

    For every sample, the image features produced by ``clip_mask_model`` are
    compared with CLIP text embeddings of ``'a photo of a <name>'`` prompts:
    first the sample's own label (``batch_y[2][j]``, presumably the class name
    string), then every other category annotated in the same LVIS image
    (looked up in the module-level ``objects_per_image``).

    Args:
        batches_or_dataset: either a list of already loaded ``(batch, batch_y)``
            pairs, which is iterated completely, or a dataset, which is wrapped
            in a ``DataLoader(shuffle=False, batch_size=32)`` and limited to
            the first 50 batches.
        process: callable applied to ``batch``. If it returns a dict, the dict
            is forwarded as keyword arguments to
            ``clip_mask_model.visual_forward``; otherwise the result is treated
            as an image tensor, moved to ``clip_device`` and resized to
            224x224 (bilinear) before the forward pass.
        mask: unused; kept for backward compatibility.
        clipmask: unused; kept for backward compatibility.

    Returns:
        ``(valid_sims, all_prompts)``: two lists with one entry per sample.
        ``valid_sims[n]`` is the softmax over the prompts in ``all_prompts[n]``;
        index 0 always belongs to the sample's own class prompt.

    Note:
        Sample metadata is read from the module-level ``lvis`` by running
        position, so the batches must follow the order of ``lvis.sample_ids``.
    """
    # These intermediates were module globals before and stay that way, so they
    # remain available for inspection after a run.
    global batch, logits, sim

    all_prompts = []
    valid_sims = []

    with torch.no_grad():
        torch.manual_seed(571)

        if isinstance(batches_or_dataset, list):
            loader = batches_or_dataset  # already loaded
            max_iter = float('inf')
        else:
            loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)
            max_iter = 50

        logit_scale = clip_model.logit_scale.exp()

        # Running index of the first sample of the current batch. Using
        # `bs * i_batch` instead would point at the wrong samples as soon as a
        # batch is smaller than the preceding ones (e.g. the last batch).
        sample_offset = 0

        for i_batch, (batch, batch_y) in enumerate(loader):
            if i_batch >= max_iter:
                break

            # `process` is evaluated exactly once per batch.
            processed_batch = process(batch)
            if isinstance(processed_batch, dict):
                image_features = clip_mask_model.visual_forward(**processed_batch)[0]
            else:
                processed_batch = processed_batch.to(clip_device)
                processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')
                image_features = clip_mask_model.visual_forward(processed_batch)[0]

            image_features = image_features.to(clip_device).half()
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            bs = len(batch[0])
            for j in range(bs):
                c, _, sid, _ = lvis.sample_ids[sample_offset + j]
                # The sample's base name is converted to int and used as the
                # key into objects_per_image.
                support_image = basename(lvis.samples[c][sid])

                target = batch_y[2][j]
                img_objs = [o.replace('_', ' ') for o in objects_per_image[int(support_image)]]
                other_words = [f'a photo of a {o}' for o in img_objs if o != target]

                prompts = [f'a photo of a {target}'] + other_words
                all_prompts.append(prompts)

                text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))
                text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)

                logits = logit_scale * image_features[j] @ text_cond.T
                sim = torch.softmax(logits, dim=-1)
                valid_sims.append(sim)

            sample_offset += bs

    # The number of prompts differs between samples, so the per-sample vectors
    # are returned as a list rather than stacked into one tensor.
    return valid_sims, all_prompts
def new_img_preprocess(x):
    """Pack a batch into the dict form that get_similarities forwards as keyword arguments.

    ``x[1]`` becomes ``'x_inp'`` and ``x[2]`` becomes the last element of the
    ``'mask'`` tuple ``(11, 'cls_token', x[2])``. The 11 is presumably a layer
    index (cf. the 'clip mask CLS L11' label used elsewhere in this file).
    """
    image = x[1]
    mask_tensor = x[2]
    return dict(x_inp=image, mask=(11, 'cls_token', mask_tensor))
#get_similarities(lvis, partial(img_preprocess, center_context=0.5));
# Trial run with the plain `x[1]` selector (the same function later used for
# the baseline); the result is discarded.
get_similarities(lvis, lambda x: x[1]);

# [label, preprocessing function] pairs to evaluate. The label is what ends up
# in the first column of the printed table; commented-out entries are variants
# that are currently disabled.
preprocessing_functions = [
    # ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],
    # ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],
    # ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],
    # ['colorize object red', partial(img_preprocess, colorize=True)],
    # ['add red outline', partial(img_preprocess, outline=True)],
    # ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],
    # ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],
    # ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],
    # ['BG blur', partial(img_preprocess, blur=3)],
    # ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],
    # ['crop large context', partial(img_preprocess, center_context=0.5)],
    # ['crop small context', partial(img_preprocess, center_context=0.1)],
    ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],
    # NOTE(review): labelled "crop" but no center_context is passed; the
    # arguments equal those of the disabled 'BG blur & intensity 10%' entry.
    # Confirm whether the label or the arguments are the intended ones.
    ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],
    # ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],
]
def _mean_gain(sims):
    """Average change of the first-prompt probability relative to `base`.

    Only samples whose baseline prompt list has at least three entries are
    taken into account.
    """
    return np.mean([
        sims[i][0].cpu() - base[i][0].cpu()
        for i, prompts in enumerate(base_p)
        if len(prompts) >= 3
    ])


# Baseline: similarities obtained with the plain `x[1]` selector.
base, base_p = get_similarities(lvis, lambda x: x[1])

# One (similarities, prompts) result per configured preprocessing function.
outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]

# Extra single-entry run that is only printed, not added to the table.
extra_functions = [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]
outs2 = [get_similarities(lvis, fun) for _, fun in extra_functions]
for extra_sims, _ in outs2[:1]:
    print(_mean_gain(extra_sims))

from pandas import DataFrame

# label -> mean gain over the baseline, printed as LaTeX table rows in percent.
tab = {
    name: _mean_gain(sims)
    for (name, _), (sims, _) in zip(preprocessing_functions, outs)
}
print('\n'.join(f'{k} & {v*100:.2f} \\\\' for k,v in tab.items()))
from evaluation_utils import denorm, norm
def load_sample(filename, filename2):
    """Load an image and a second (mask) image from the sample-image folder.

    Both paths are relative to ``~/cloud/resources/sample_images``. Each image
    is converted to a tensor, resized to 224 and center-cropped to 224; only
    the first image is additionally normalized with the given mean/std.

    Args:
        filename: path of the image.
        filename2: path of the second image, of which only the first channel
            is kept.

    Returns:
        A three-element list ``[None, image, mask]`` where ``image`` has a
        leading batch dimension of 1 added and ``mask`` is the first channel
        of the second image (channel dimension retained).
    """
    from os.path import join

    bp = expanduser('~/cloud/resources/sample_images')

    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Resize(224),
        transforms.CenterCrop(224)
    ])

    # Same pipeline without normalization, used for the second image.
    tf2 = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(224),
        transforms.CenterCrop(224)
    ])

    # The context managers close the image files as soon as the pixel data has
    # been converted, instead of leaving the handles to the garbage collector.
    with Image.open(join(bp, filename)) as image_file:
        image = tf(image_file)
    with Image.open(join(bp, filename2)) as mask_file:
        mask = tf2(mask_file)

    return [None, image.unsqueeze(0), mask[:1]]
def all_preprocessing(inp1):
    """Run `img_preprocess` on `inp1` once per showcased variant.

    Returns a list with one result per variant, in the order listed below
    (the first entry uses `img_preprocess`' defaults).
    """
    variants = (
        {},
        {'colorize': True},
        {'outline': True},
        {'blur': 3},
        {'bg_fac': 0.1},
        # {'bg_fac': 0.5},
        # {'blur': 3, 'bg_fac': 0.5},
        {'blur': 3, 'bg_fac': 0.5, 'center_context': 0.5},
    )
    return [img_preprocess(inp1, **kwargs) for kwargs in variants]
from torchvision import transforms
from PIL import Image
from matplotlib import pyplot as plt
from evaluation_utils import img_preprocess
import clip
# Each entry: [sample returned by load_sample, candidate object names].
# The first name of every list gets the highlight colour in the bar charts.
images_queries = [
    [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],
    [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],
]

# Two rows per query (images on top, probability bars below) and six columns,
# one per variant returned by all_preprocessing.
_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))

# Loaded once on the CPU instead of once per query; this still rebinds the
# module-level clip_model / preprocess, as the in-loop load did.
clip_model, preprocess = clip.load("ViT-B/16", device='cpu')

for j, (images, objects) in enumerate(images_queries):
    joint_image = torch.stack(all_preprocessing(images))[:, 0]
    prompts = [f'a photo of a {obj}' for obj in objects]

    # Inference only: no autograd graph is needed for the plotted values.
    with torch.no_grad():
        image_features = clip_model.encode_image(joint_image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        text_cond = clip_model.encode_text(clip.tokenize(prompts))
        text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)

        logits = clip_model.logit_scale.exp() * image_features @ text_cond.T
        sim = torch.softmax(logits, dim=-1).detach().cpu()

    for i, variant in enumerate(joint_image):
        image_ax = ax[2 * j, i]
        bar_ax = ax[2 * j + 1, i]

        image_ax.axis('off')
        # Undo the normalization and clamp to [0, 1] for display.
        image_ax.imshow(torch.clamp(denorm(variant).permute(1, 2, 0), 0, 1))

        bar_ax.grid(True)
        bar_ax.set_ylim(0, 1)
        bar_ax.set_yticklabels([])
        bar_ax.set_xticks([])  # set_xticks(range(len(prompts)))
        # ax[1, i].set_xticklabels(objects, rotation=90)

        for k, score in enumerate(sim[i]):
            bar_ax.bar(k, score, color=plt.cm.tab20(3) if k == 0 else plt.cm.tab20(1))
            # Object names are written inside the plot since the x ticks are hidden.
            bar_ax.text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)

plt.tight_layout()
plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')