%matplotlib inline
%reload_ext autoreload
%autoreload 2
from fastai.conv_learner import *
torch.backends.cudnn.benchmark = True
import fastText as ft
import torchvision.transforms as transforms
# Standard torchvision augmentation pipeline with ImageNet channel statistics.
# NOTE(review): this `tfms` is later rebound by `tfms_from_model(...)` (the fastai
# pipeline); this torchvision version appears unused downstream — confirm before removing.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
tfms = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])
%mkdir data/imagenet
%cd data/imagenet/
/home/ubuntu/data/imagenet
!aria2c --file-allocation=none -c -x 5 -s 5 http://files.fast.ai/data/imagenet-sample-train.tar.gz
*** Download Progress Summary as of Thu Jul 5 15:24:15 2018 *** =============================================================================== [#67c134 1.8GiB/2.0GiB(90%) CN:5 DL:28MiB ETA:6s] FILE: /home/ubuntu/data/imagenet/imagenet-sample-train.tar.gz ------------------------------------------------------------------------------- [#67c134 2.0GiB/2.0GiB(99%) CN:2 DL:27MiB] 07/05 15:24:23 [NOTICE] Download complete: /home/ubuntu/data/imagenet/imagenet-sample-train.tar.gz Download Results: gid |stat|avg speed |path/URI ======+====+===========+======================================================= 67c134|OK | 31MiB/s|/home/ubuntu/data/imagenet/imagenet-sample-train.tar.gz Status Legend: (OK):download completed.
!tar -xzf imagenet-sample-train.tar.gz
%ls -la train/ | head -n 10
total 3128 drwxrwxr-x 776 ubuntu ubuntu 24576 Jan 19 2017 ./ drwxrwxr-x 3 ubuntu ubuntu 4096 Jul 5 15:27 ../ drwxrwxr-x 2 ubuntu ubuntu 4096 Nov 16 2016 n01440764/ drwxrwxr-x 2 ubuntu ubuntu 4096 Nov 16 2016 n01443537/ drwxrwxr-x 2 ubuntu ubuntu 4096 Nov 16 2016 n01491361/ drwxrwxr-x 2 ubuntu ubuntu 4096 Nov 16 2016 n01494475/ drwxrwxr-x 2 ubuntu ubuntu 4096 Nov 16 2016 n01498041/ drwxrwxr-x 2 ubuntu ubuntu 4096 Nov 16 2016 n01514668/ drwxrwxr-x 2 ubuntu ubuntu 4096 Nov 16 2016 n01514859/ ls: write error
%cd ../..
/home/ubuntu
# Paths for the ImageNet sample data and the fastText word vectors.
fname = 'n01440764/n01440764_9649.JPEG'  # one sample image (wnid n01440764 = tench)
PATH = Path('data/imagenet/')
TMP_PATH = PATH / 'tmp'
TRANS_PATH = Path('data/translate/') # for fastText word vectors
PATH_TRN = PATH / 'train'
# Open and display the sample image to sanity-check the paths.
img = Image.open(PATH_TRN / fname)
img
Data pipeline
import fastai
fastai.dataloader.DataLoader
fastai.dataloader.DataLoader
arch = resnet50
# fastai train/validation transform pipelines for a 224px resnet50 input.
ttfms, vtfms = tfms_from_model(arch, 224, transforms_side_on, max_zoom=1.1)
def to_array(x, y):
    """Convert a PIL image to a float32 array scaled to [0, 1].

    Matches the fastai (x, y) transform signature; the label is dropped
    (returned as None) because targets are handled elsewhere.
    """
    # Build float32 in one step: np.array(x).astype(np.float32) creates an
    # intermediate array and then copies it — dtype= avoids the extra copy.
    return np.array(x, dtype=np.float32) / 255, None
def TT(x, y):
    """Wrap the numpy image in a torch tensor; the (unused) label becomes None."""
    tensor = torch.from_numpy(x)
    return tensor, None
# Prepend the PIL->float32 conversion to the training transform chain.
# The torch-tensor step (TT) is intentionally left commented out here.
ttfms.tfms = [to_array] + ttfms.tfms# + [TT]
ttfms(img)
array([[[ 0.404 , 0.88463, 0.64182, ..., 1.0088 , 0.77503, 1.89105], [ 0.85854, 1.58429, 1.3317 , ..., 0.89556, 0.72348, 2.00313], [ 0.88372, 1.01779, 0.83366, ..., 0.8472 , 0.74346, 2.149 ], ..., [-0.00787, 0.28974, 0.65969, ..., 0.52232, 0.59659, 0.68792], [-0.03541, 0.33924, 0.60903, ..., 0.60172, 0.66966, 0.78727], [-0.14989, 0.30909, 0.56169, ..., 0.39474, 0.54049, 0.76156]], [[ 0.70536, 1.2093 , 0.97274, ..., 1.19664, 0.95415, 2.06959], [ 1.1683 , 1.91048, 1.66838, ..., 1.08087, 0.90103, 2.18123], [ 1.1923 , 1.33133, 1.1534 , ..., 1.03143, 0.92103, 2.32741], ..., [ 0.58756, 0.90526, 1.31344, ..., 0.92061, 0.97329, 1.05188], [ 0.5594 , 0.95923, 1.26354, ..., 1.00818, 1.06743, 1.16525], [ 0.44237, 0.93176, 1.21704, ..., 0.80195, 0.95095, 1.14759]], [[ 0.42472, 0.90539, 0.64079, ..., 1.0115 , 0.75108, 1.83496], [ 0.89957, 1.61148, 1.32816, ..., 0.88833, 0.68864, 1.93359], [ 0.92079, 1.00317, 0.79638, ..., 0.83988, 0.71185, 2.07995], ..., [ 0.17452, 0.47868, 0.84404, ..., -0.05322, 0.00694, 0.09008], [ 0.14774, 0.52738, 0.79248, ..., 0.03184, 0.09421, 0.19903], [ 0.03206, 0.49502, 0.7443 , ..., -0.18906, -0.03814, 0.16826]]], dtype=float32)
Load the Word Vectors
# fastText word vectors
# Load the pre-trained English fastText model (binary .bin format).
ft_vecs = ft.load_model(str((TRANS_PATH / 'wiki.en.bin')))
ft_vecs.get_word_vector('king') # returns numpy.ndarray of shape (300,)
array([ 0.03259, -0.18164, -0.29049, -0.10506, -0.16712, -0.07748, -0.5661 , -0.08622, -0.00216, 0.15366, 0.12189, -0.14722, 0.01511, 0.07209, -0.02156, -0.20612, -0.02104, -0.01999, -0.15506, 0.00802, -0.22746, 0.33518, -0.10629, -0.50318, -0.1582 , 0.27829, 0.05752, -0.32697, 0.04766, 0.01076, 0.13972, -0.12445, -0.18989, 0.32969, -0.32513, 0.10958, 0.21962, -0.47215, 0.03422, -0.2207 , 0.02177, 0.0832 , -0.04776, -0.48873, 0.05207, -0.15001, -0.19203, 0.06177, 0.15535, -0.05598, 0.11071, 0.39161, -0.17716, 0.05449, 0.25898, -0.13954, 0.4272 , -0.07273, -0.4714 , 0.04993, 0.29526, -0.05319, 0.03451, -0.10583, -0.30137, 0.16372, 0.07541, 0.21018, -0.11459, 0.10976, 0.04923, 0.17688, 0.45658, -0.59762, -0.0039 , 0.08866, 0.53103, 0.153 , -0.1673 , 0.13121, -0.05547, -0.03582, -0.34535, 0.09128, 0.03323, 0.45211, -0.16894, 0.21139, 0.24153, 0.51014, -0.01474, -0.47179, 0.2235 , -0.34668, 0.12126, 0.23727, -0.08424, 0.04555, -0.07698, 0.0428 , -0.13887, 0.29286, -0.28864, 0.53446, 0.02677, -0.04119, 0.40156, 0.38334, 0.01935, 0.02089, 0.02142, -0.11958, -0.44997, 0.13685, -0.12185, -0.00509, 0.60342, 0.65889, -0.16251, 0.46393, 0.19732, 0.19346, -0.07765, 0.17387, 0.07279, 0.04365, -0.01246, 0.4392 , 0.03182, 0.34927, -0.13155, 0.41265, 0.1348 , 0.03162, 0.17821, 0.20899, -0.03224, -0.37799, 0.23646, 0.10512, -0.00483, 0.33617, 0.43214, 0.28264, 0.01725, 0.35155, 0.28504, -0.41468, -0.20859, 0.08935, -0.08568, -0.3982 , -0.61611, 0.574 , -0.34191, 0.03569, 0.08309, 0.02758, 0.30767, -0.14426, -0.23718, 0.19269, 0.12444, 0.20298, -0.08636, -0.30212, 0.06119, 0.08865, 0.60565, 0.23092, -0.16018, -0.44802, -0.14103, 0.08389, 0.08604, 0.17387, -0.11659, 0.15751, -0.25178, 0.12577, 0.28713, -0.00183, 0.05259, -0.0495 , -0.03082, 0.13133, -0.00867, 0.00691, 0.30406, 0.18153, -0.05479, -0.39295, 0.29229, 0.27204, 0.01185, 0.02325, 0.02535, -0.21103, -0.45489, 0.10004, 0.26659, -0.12585, -0.03636, -0.1304 , -0.10385, -0.35109, -0.04138, 0.20202, 0.08724, -0.22088, 0.25375, 
0.08034, 0.0022 , -0.14621, -0.16164, 0.12694, -0.01651, -0.11299, -0.06235, 0.15739, -0.20588, -0.09687, -0.22731, -0.10299, -0.02208, 0.1705 , -0.41714, 0.13382, -0.09988, -0.35683, 0.49678, -0.00604, -0.09917, 0.28355, 0.27951, 0.09213, 0.12555, 0.12955, 0.05188, -0.14202, -0.18416, -0.48024, -0.02423, 0.10908, -0.04117, -0.20895, -0.30235, 0.47612, -0.22305, -0.41871, -0.03084, 0.02981, 0.21836, -0.04544, -0.24222, 0.0735 , -0.16438, -0.05721, 0.31028, 0.26954, 0.20621, 0.04835, 0.10146, -0.2655 , 0.00589, -0.0269 , 0.05519, 0.2096 , -0.21835, 0.12025, -0.44548, 0.05322, -0.23166, 0.03323, 0.13661, -0.39058, 0.1834 , 0.01626, -0.19765, 0.14757, -0.06413, 0.34661, 0.31601, 0.13334, -0.53255, 0.26908, 0.27234, -0.1101 , -0.11572, -0.42586, 0.21509, -0.23383, 0.07461, 0.30356, 0.0955 , -0.30532, -0.2858 , 0.27764, 0.04028, -0.09576], dtype=float32)
np.corrcoef( ft_vecs.get_word_vector('jeremy'), ft_vecs.get_word_vector('Jeremy') )
array([[1. , 0.60866], [0.60866, 1. ]])
np.corrcoef(ft_vecs.get_word_vector('banana'), ft_vecs.get_word_vector('Jeremy'))
array([[1. , 0.14482], [0.14482, 1. ]])
# All fastText vocabulary words together with their corpus frequencies.
ft_words = ft_vecs.get_words(include_freq=True)
ft_word_dict = { k: v for k, v in zip(*ft_words) }
# Sort words by ascending frequency so the most frequent words end up last
# (ft_words[-1000000:] later selects the 1M most frequent words).
ft_words = sorted(ft_word_dict.keys(), key=lambda x: ft_word_dict[x])
len(ft_words)
2519370
Get ImageNet classes
from fastai.io import get_data
# Download the ImageNet class-index mapping: class number -> [wnid, class_name].
CLASSES_FN = 'imagenet_class_index.json'
get_data(f'http://files.fast.ai/models/{CLASSES_FN}', TMP_PATH / CLASSES_FN)
imagenet_class_index.json: 41.0kB [00:00, 54.7kB/s]
Get all nouns in English (Wordnet)
# Download the WordNet noun class ids ("wnid noun" per line, ~82k entries).
WORDS_FN = 'classids.txt'
get_data(f'http://files.fast.ai/data/{WORDS_FN}', PATH / WORDS_FN)
classids.txt: 1.74MB [00:02, 816kB/s]
Create ImageNet class number to words.
# class_dict maps ImageNet class number (as str) -> [wnid, class_name].
# Use a context manager: json.load((...).open()) leaks the file handle.
with (TMP_PATH / CLASSES_FN).open() as f:
    class_dict = json.load(f)
# classids_1k: wnid -> class_name for the 1000 ImageNet classes.
classids_1k = dict(class_dict.values())
nclass = len(class_dict)
nclass
1000
class_dict['0']
['n01440764', 'tench']
Wordnet class number to nouns
# Read the "wnid noun" lines from the WordNet class file.
# Context manager instead of (...).open().readlines(), which leaks the handle.
with (PATH / WORDS_FN).open() as f:
    classid_lines = f.readlines()
classid_lines[:5]
['n00001740 entity\n', 'n00001930 physical_entity\n', 'n00002137 abstraction\n', 'n00002452 thing\n', 'n00002684 object\n']
# Build wnid -> noun name for all WordNet noun synsets (~82k entries).
classids = dict( l.strip().split() for l in classid_lines )
len(classids), len(classids_1k)
(82115, 1000)
Look up all the nouns in fastText
# Word vectors for the 1M most frequent fastText words, keyed by lowercase word.
lc_vec_d = { w.lower(): ft_vecs.get_word_vector(w) for w in ft_words[-1000000:] }
# Pair each WordNet wnid with its noun's vector; drop nouns with no vector.
syn_wv = [(k, lc_vec_d[v.lower()]) for k, v in classids.items()
if v.lower() in lc_vec_d]
# Same pairing, restricted to the 1000 ImageNet classes.
syn_wv_1k = [(k, lc_vec_d[v.lower()]) for k, v in classids_1k.items()
if v.lower() in lc_vec_d]
syn2wv = dict(syn_wv)
len(syn2wv)
49469
Save the lookups
# Persist the lookups so later sessions can skip the slow fastText load.
# Context managers replace bare (...).open(...) calls, which leak file handles.
with (TMP_PATH / 'syn2wv.pkl').open('wb') as f:
    pickle.dump(syn2wv, f)
with (TMP_PATH / 'syn_wv_1k.pkl').open('wb') as f:
    pickle.dump(syn_wv_1k, f)
# Reload immediately (a fresh session can start from here).
with (TMP_PATH / 'syn2wv.pkl').open('rb') as f:
    syn2wv = pickle.load(f)
with (TMP_PATH / 'syn_wv_1k.pkl').open('rb') as f:
    syn_wv_1k = pickle.load(f)
# Build parallel lists: relative image path and that image's target word vector.
images = []
img_vecs = []
for d in (PATH / 'train').iterdir():
# Skip class directories whose wnid has no fastText vector.
if d.name not in syn2wv:
continue
vec = syn2wv[d.name]
for f in d.iterdir():
images.append(str(f.relative_to(PATH)))
# Every image of a class shares that class's word vector as its regression target.
img_vecs.append(vec)
# n_val=0
# for d in (PATH/'valid').iterdir():
# if d.name not in syn2wv: continue
# vec = syn2wv[d.name]
# for f in d.iterdir():
# images.append(str(f.relative_to(PATH)))
# img_vecs.append(vec)
# n_val += 1
# Hold out the last n_val images as the validation set (~20% of the data).
n_val = 3888 # hardcode. split data into train and validation set. validation set is 20% of train set.
n = len(images)
n
14048
14048
# Validation indices are the final n_val positions of the `images` list.
val_idxs = list(range(n - n_val, n))
len(val_idxs), val_idxs[0]
(3888, 10160)
#DEBUG
print(images[0])
print(img_vecs[0].shape)
print(len(img_vecs))
train/n03337140/n03337140_3342.JPEG (300,) 14048
# Stack the per-image vectors into a single (n_images, 300) array.
img_vecs = np.stack(img_vecs)
img_vecs.shape
(14048, 300)
# Cache the image list and target vectors; reload to prove the round-trip.
# Context managers replace bare (...).open(...) calls, which leak file handles.
with (TMP_PATH / 'images.pkl').open('wb') as f:
    pickle.dump(images, f)
with (TMP_PATH / 'img_vecs.pkl').open('wb') as f:
    pickle.dump(img_vecs, f)
with (TMP_PATH / 'images.pkl').open('rb') as f:
    images = pickle.load(f)
with (TMP_PATH / 'img_vecs.pkl').open('rb') as f:
    img_vecs = pickle.load(f)
================================ START DEBUG ================================
#DEBUG
!find data/imagenet/train -type f | wc -l
19439
#DEBUG
!ls data/imagenet/train/n01440764
n01440764_10365.JPEG n01440764_2006.JPEG n01440764_8013.JPEG n01440764_11155.JPEG n01440764_2418.JPEG n01440764_8063.JPEG n01440764_11787.JPEG n01440764_26631.JPEG n01440764_8426.JPEG n01440764_12241.JPEG n01440764_3281.JPEG n01440764_853.JPEG n01440764_12732.JPEG n01440764_3421.JPEG n01440764_8938.JPEG n01440764_13275.JPEG n01440764_4934.JPEG n01440764_9567.JPEG n01440764_14405.JPEG n01440764_529.JPEG n01440764_9649.JPEG n01440764_1713.JPEG n01440764_63.JPEG n01440764_188.JPEG n01440764_6878.JPEG
#DEBUG
# new_file = 'data/imagenet/train/n01498041/n01498041_10412.JPEG'
new_file = 'data/imagenet/train/n01440764/n01440764_6878.JPEG'
new_im = Image.open(new_file).resize((224,224), Image.BILINEAR)
new_im
================================ END DEBUG ================================
Create the model architecture + datasets
arch = resnet50
# transformers of images for training
tfms = tfms_from_model(arch, 224, transforms_side_on, max_zoom=1.1)
# we can pass all the names from imagenet + word vecs
# then pass the indexes
# continuous = True - since we are predicting vectors
md = ImageClassifierData.from_names_and_array(PATH, images, img_vecs, val_idxs=val_idxs,
classes=None, tfms=tfms, continuous=True, bs=256)
# Sanity check: pull one validation batch.
x, y = next(iter(md.val_dl))
"""
md.c - pass number of classes to the c argument (size of the last layer)
is_multi - not multiclass
is_reg - is regression
xtra_fc - extra fully connected layers
ps - how much dropout do you want?
*note no softmax
"""
# Regression head on resnet50: one extra 1024-unit FC layer, 20% dropout,
# no softmax (is_reg=True) — the model emits raw 300-dim vectors.
models = ConvnetBuilder(arch, md.c, is_multi=False, is_reg=True, xtra_fc=[1024], ps=[0.2, 0.2])
learn = ConvLearner(md, models, precompute=True)
# Adam with beta2=0.99 (slightly lower than the 0.999 default).
learn.opt_fn = partial(optim.Adam, betas=(0.9, 0.99))
Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /home/ubuntu/.torch/models/resnet50-19c8e357.pth 100%|██████████| 102502400/102502400 [00:01<00:00, 72565393.28it/s]
100%|██████████| 40/40 [01:31<00:00, 2.28s/it] 100%|██████████| 16/16 [00:36<00:00, 2.28s/it]
# loss function - L1 loss is the difference
# but since we are doing high-dimensional vectors, most of the items
# are on the outside and the distance metric isn't the best metric
def cos_loss(inp, targ):
    """Cosine-distance loss: 1 minus the mean cosine similarity over the batch.

    Zero when every prediction points exactly along its target's direction;
    direction matters, magnitude does not.
    """
    sims = F.cosine_similarity(inp, targ)
    return 1 - sims.mean()
learn.crit = cos_loss
Train the model
Training duration ~1 hour.
learn.lr_find(start_lr=1e-4, end_lr=1e15)
HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))
85%|████████▌ | 34/40 [00:02<00:00, 13.45it/s, loss=nan]
learn.sched.plot()
lr = 1e-2
wd = 1e-7
# train with precompute = True to cut down on training time
learn.precompute = True
# One 20-epoch cycle with cyclical LR (use_clr) and light weight decay.
learn.fit(lr, 1, cycle_len=20, wds=wd, use_clr=(20, 10))
HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))
epoch trn_loss val_loss 0 0.533586 0.470473 1 0.372923 0.486955 2 0.293371 0.49963 3 0.236202 0.505895 4 0.195004 0.510554 5 0.165844 0.516996 6 0.144815 0.530448 7 0.129941 0.523714 8 0.117989 0.525584 9 0.109467 0.523132 10 0.102434 0.526665 11 0.09594 0.528045 12 0.090793 0.525027 13 0.08635 0.530179 14 0.082674 0.52541 15 0.078416 0.524496 16 0.07525 0.529237 17 0.072656 0.527995 18 0.070164 0.527018 19 0.068064 0.528724
[array([0.52872])]
# Freeze batchnorm statistics, then run another 20-epoch cycle.
learn.bn_freeze(True)
learn.fit(lr, 1, cycle_len=20, wds=wd, use_clr=(20, 10))
HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))
epoch trn_loss val_loss 0 0.055475 0.533504 1 0.061694 0.543637 2 0.069302 0.537233 3 0.066792 0.538912 4 0.059769 0.534378 5 0.053277 0.531469 6 0.048054 0.533863 7 0.043353 0.534298 8 0.039795 0.538832 9 0.036677 0.538117 10 0.033617 0.546751 11 0.031627 0.539823 12 0.029719 0.530515 13 0.027769 0.547381 14 0.025036 0.548819 15 0.023828 0.538898 16 0.022615 0.535674 17 0.021772 0.535489 18 0.020845 0.544093 19 0.020268 0.545169
[array([0.54517])]
# Differential learning rates for the three layer groups (early layers slowest).
lrs = np.array([lr / 1000, lr / 100, lr])
learn.precompute = False
# Unfreeze everything above the first layer group for fine-tuning.
learn.freeze_to(1)
learn.save('pre0')
learn.load('pre0')
# syn_wv_1k is ImageNet 1000 classes (syn) mapped to fastText word vectors
syns, wvs = list(zip(*syn_wv_1k)) # split tuple of syn_id and word vector into 2 list, syn_ids, word vectors
wvs = np.array(wvs)
# DEBUG
syn_wv_1k[0][0], syn_wv_1k[0][1][:10]
('n01440764', array([ 0.01299, 0.51545, -0.02986, -0.17743, -0.13517, 0.09963, 0.15457, -0.29894, 0.06537, -0.32881], dtype=float32))
# DEBUG
syns[0], wvs[0][:10]
('n01440764', array([ 0.01299, 0.51545, -0.02986, -0.17743, -0.13517, 0.09963, 0.15457, -0.29894, 0.06537, -0.32881], dtype=float32))
%time pred_wv = learn.predict()
CPU times: user 1min 8s, sys: 4.99 s, total: 1min 13s Wall time: 33.2 s
Let's take a look at some of the pictures
# Offset into the validation set for the example images shown below.
start = 512
# Denormalization fn that maps model inputs back to displayable images.
denorm = md.val_ds.denorm
def show_img(im, figsize=None, ax=None):
    """Display image `im` on `ax`, creating a fresh figure when ax is None.

    Returns the axes so callers can annotate further.
    """
    # Compare against None explicitly: `if not ax` tests the truthiness of a
    # matplotlib Axes object, which is fragile — identity with None is the intent.
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(im)
    ax.axis('off')
    return ax
def show_imgs(ims, cols, figsize=None):
    """Show `ims` in a grid with `cols` columns (len(ims) divisible by cols)."""
    rows = len(ims) // cols
    fig, axes = plt.subplots(rows, cols, figsize=figsize)
    # zip stops at the shorter sequence, so exactly rows*cols images are drawn.
    for im, ax in zip(ims, axes.flat):
        show_img(im, ax=ax)
    plt.tight_layout()
show_imgs(denorm(md.val_ds[start:start + 25][0]), 5, (10, 10))
Search 300D vector, what are the closest neighbors?
There is an amazing, little-known library called NMSLib that does this incredibly fast. It uses multi-threading and is absolutely fantastic. You can install it from pip (`pip install nmslib`) and it just works.
!pip install nmslib
Collecting nmslib
Downloading https://files.pythonhosted.org/packages/de/eb/28b2060bb1750426c5618e3ad6ce830ac3cfd56cb3eccfb799e52d6064db/nmslib-1.7.2.tar.gz (254kB)
100% |████████████████████████████████| 256kB 1.3MB/s ta 0:00:01
Requirement already satisfied: pybind11>=2.0 in ./anaconda3/envs/fastai/lib/python3.6/site-packages (from nmslib)
Requirement already satisfied: numpy in ./anaconda3/envs/fastai/lib/python3.6/site-packages (from nmslib)
Building wheels for collected packages: nmslib
Running setup.py bdist_wheel for nmslib ... done
Stored in directory: /home/ubuntu/.cache/pip/wheels/01/5d/03/201cc23dfe226021fd08f1eac9d03df73473eed25ed4e557c7
Successfully built nmslib
Installing collected packages: nmslib
Successfully installed nmslib-1.7.2
You are using pip version 9.0.3, however version 10.0.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.
import nmslib
def create_index(a):
    """Build and return an NMSLib angular-distance index over the rows of `a`."""
    idx = nmslib.init(space='angulardist')
    idx.addDataPointBatch(a)
    idx.createIndex()
    return idx
def get_knns(index, vecs):
    """Batched 10-NN query: returns the (indices, distances) pair, transposed
    so the caller can unpack all index lists and all distance lists at once."""
    results = index.knnQueryBatch(vecs, k=10, num_threads=4)
    return zip(*results)
def get_knn(index, vec):
    """Return (indices, distances) of the 10 nearest neighbours of one vector."""
    neighbours = index.knnQuery(vec, k=10)
    return neighbours
# Index over the 1000 ImageNet class word vectors.
nn_wvs = create_index(wvs)
# Nearest classes for every predicted validation vector.
idxs, dists = get_knns(nn_wvs, pred_wv)
# DEBUG
wn_p = list(zip(list(classids.keys()), list(classids.values())))[100:105]
print(wn_p)
syns[0]
[('n00045646', 'rally'), ('n00045907', 'recovery'), ('n00046177', 'running_away'), ('n00046344', 'stunt'), ('n00046522', 'touch')]
'n01440764'
# ImageNet classes
# Top-3 ImageNet class names for ten predicted validation vectors.
[ [classids[syns[id]] for id in ids[:3]] for ids in idxs[start:start + 10] ]
[['mink', 'polecat', 'cougar'], ['badger', 'polecat', 'otter'], ['marmot', 'badger', 'polecat'], ['marmot', 'badger', 'mink'], ['polecat', 'badger', 'skunk'], ['mink', 'polecat', 'beaver'], ['polecat', 'cougar', 'badger'], ['dingo', 'wombat', 'polecat'], ['cockroach', 'bathtub', 'plunger'], ['polecat', 'skunk', 'mink']]
What if we now bring in WordNet.
# DEBUG
# syn2wv is of type dict
syn2wv['n00001740'][:10] # returns 300 dimensional word vector
array([ 0.02561, 0.17057, -0.12382, 0.3527 , -0.06303, 0.08731, 0.14308, -0.32462, -0.31296, 0.09208], dtype=float32)
# All WordNet synsets with vectors, split into parallel id / vector sequences.
all_syns, all_wvs = list(zip(*syn2wv.items()))
all_wvs = np.array(all_wvs)
# DEBUG
all_syns[0], all_wvs[0][:10]
('n00001740', array([ 0.02561, 0.17057, -0.12382, 0.3527 , -0.06303, 0.08731, 0.14308, -0.32462, -0.31296, 0.09208], dtype=float32))
# Nearest neighbour search over word vectors for ALL WordNet noun classes.
nn_allwvs = create_index(all_wvs)
idxs, dists = get_knns(nn_allwvs, pred_wv)
# Top-3 WordNet nouns for ten predicted validation vectors.
[ [classids[all_syns[id]] for id in ids[:3]] for ids in idxs[start:start + 10] ]
[['mink', 'mink', 'mink'], ['badger', 'polecat', 'raccoon'], ['marmot', 'Marmota', 'badger'], ['marmot', 'Marmota', 'badger'], ['polecat', 'Mustela', 'stoat'], ['mink', 'mink', 'mink'], ['polecat', 'Mustela', 'cougar'], ['dog', 'dog', 'alligator'], ['nosepiece', 'sweatband', 'sweatband'], ['polecat', 'Mustela', 'Melogale']]
# Index over the model's predicted vectors, so we can search images by text.
nn_predwv = create_index(pred_wv)
# en_vecd: {english_word: 300-dim fastText vector}.
# Context manager instead of a bare open(...), which leaks the file handle.
with open(TRANS_PATH / 'wiki.en.pkl', 'rb') as f:
    en_vecd = pickle.load(f)
def text2img(vec):
    """Show the 3 validation images whose predicted vectors lie closest to `vec`."""
    # get indices of the nearest predicted image vectors (distances unused)
    neighbour_ids, _ = get_knn(nn_predwv, vec)
    picks = [open_image(PATH / md.val_ds.fnames[i]) for i in neighbour_ids[:3]]
    show_imgs(picks, 3, figsize=(9, 3))
# en_vecd is of type dict, i.e. { 'sink': 300-dim word vector }
vec = en_vecd['boat'] # get the vector for boat
text2img(vec) # pull images whose vectors are close to our 'boat' vector
# Averaging two word vectors finds images matching both concepts at once.
vec = (en_vecd['engine'] + en_vecd['boat']) / 2
text2img(vec)
vec = (en_vecd['sail'] + en_vecd['boat']) / 2
text2img(vec)
# Image -> image search: embed one training image and find its neighbours.
fname = 'train/n01440764/n01440764_9649.JPEG'
img = open_image(PATH / fname)
show_img(img)
<matplotlib.axes._subplots.AxesSubplot at 0x7ef974b2f860>
# Apply the validation transforms, then predict the image's word vector.
t_img = md.val_ds.transform(img)
pred = learn.predict_array(t_img[None]) # t_img[None] give us batch size of 1
idxs, dists = get_knn(nn_predwv, pred) # pred is word vector, not probs
im_res = [open_image(PATH / md.val_ds.fnames[i]) for i in idxs[1:4]] # note we are getting idxs from 1 to 4, not 0 to 3.
show_imgs(im_res, 3, figsize=(9,3))