%matplotlib inline
%reload_ext autoreload
%autoreload 2
from fastai.conv_learner import *
torch.backends.cudnn.benchmark=True
import fastText as ft
import torchvision.transforms as transforms
# Standard ImageNet channel statistics for input normalization
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# Training-style augmentation pipeline (torchvision), producing normalized 3x224x224 tensors.
# RandomResizedCrop replaces the deprecated RandomSizedCrop alias used originally.
tfms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# Sample validation image used for spot checks later in the notebook
fname = 'valid/n01440764/ILSVRC2012_val_00007197.JPEG'
# Dataset locations — assumes ImageNet extracted under data/imagenet with
# train/ and valid/ subdirectories per synset; TODO confirm layout
PATH = Path('data/imagenet/')
TMP_PATH = PATH/'tmp'
TRANS_PATH = Path('data/translate/')
PATH_TRN = PATH/'train'
# Open the sample image (PIL Image)
img = Image.open(PATH/fname)
import fastai
# Notebook-style inspection of the DataLoader class; evaluates to the class object, no side effects
fastai.dataloader.DataLoader
arch=resnet50
# fastai train/val transform pipelines at 224px with side-on augmentation
ttfms,vtfms = tfms_from_model(arch, 224, transforms_side_on, max_zoom=1.1)
def to_array(x,y): return np.array(x).astype(np.float32)/255,None
def TT(x,y): return torch.from_numpy(x),None
# Prepend to_array so the pipeline receives [0,1] float arrays; the TT step is left disabled
ttfms.tfms = [to_array] + ttfms.tfms# + [TT]
# Run the pipeline on the sample image (produces a 3x224x224 float tensor per the output below)
ttfms(img)
( 0 ,.,.) = -1.2303e+00 -8.5119e-01 -5.9588e-01 ... 2.6367e-01 2.9539e-01 -6.1730e-03 -3.7076e-01 -2.6681e-01 -8.4245e-01 ... 2.7774e-02 -2.1419e-01 -3.8521e-02 -4.4917e-01 -8.2077e-01 -9.1286e-01 ... -1.8501e-01 -3.3902e-01 -2.7731e-01 ... ⋱ ... -2.3891e-01 -6.1463e-01 -7.5132e-01 ... -3.6532e-01 -3.1892e-01 -5.1175e-01 -3.5185e-01 -6.9721e-01 -9.4736e-01 ... -4.4875e-01 -4.6495e-01 -3.5350e-01 -6.0510e-01 -1.2168e+00 -7.6353e-01 ... -2.4128e-01 -3.2143e-01 -4.5569e-01 ( 1 ,.,.) = -1.0955e+00 -3.7223e-01 6.5788e-02 ... 3.6881e-01 5.9299e-01 1.9598e-01 2.8945e-01 6.3827e-01 -1.3719e-01 ... 1.1291e-01 7.1316e-02 2.3313e-01 3.5067e-01 -1.5094e-01 -3.2621e-01 ... -1.5517e-01 -9.9879e-02 1.8858e-01 ... ⋱ ... 8.9915e-02 1.1407e-01 6.0954e-02 ... 5.4866e-02 -3.2704e-02 -2.1574e-01 3.8106e-03 -3.5581e-01 -4.8731e-01 ... -1.6228e-01 -1.9376e-01 -7.6189e-02 -3.5776e-01 -1.0699e+00 -5.1589e-01 ... 3.9278e-02 -1.8401e-02 -1.4476e-01 ( 2 ,.,.) = -1.5169e+00 -1.2942e+00 -1.2443e+00 ... -2.9852e-01 -3.4353e-01 -3.4875e-01 -8.5310e-01 -7.1368e-01 -1.1537e+00 ... -3.7036e-01 -7.6493e-01 -5.2110e-01 -7.2973e-01 -1.1940e+00 -1.2884e+00 ... -6.7519e-01 -8.5094e-01 -8.4855e-01 ... ⋱ ... -1.1173e+00 -1.4027e+00 -1.5842e+00 ... -2.8738e-01 8.6017e-02 6.4270e-03 -1.1104e+00 -1.1831e+00 -1.4715e+00 ... 1.0541e-01 1.8515e-01 2.9458e-01 -1.3824e+00 -1.5386e+00 -1.4586e+00 ... 4.5452e-01 3.5802e-01 5.8026e-02 [torch.FloatTensor of size 3x224x224]
# Load the pretrained fastText English Wikipedia model (binary format)
ft_vecs = ft.load_model(str((TRANS_PATH/'wiki.en.bin')))
# Sanity check: fetch the vector for a common word
ft_vecs.get_word_vector('king')
array([ 0.03259, -0.18164, -0.29049, -0.10506, -0.16712, -0.07748, -0.5661 , -0.08622, -0.00216, 0.15366, 0.12189, -0.14722, 0.01511, 0.07209, -0.02156, -0.20612, -0.02104, -0.01999, -0.15506, 0.00802, -0.22746, 0.33518, -0.10629, -0.50318, -0.1582 , 0.27829, 0.05752, -0.32697, 0.04766, 0.01076, 0.13972, -0.12445, -0.18989, 0.32969, -0.32513, 0.10958, 0.21962, -0.47215, 0.03422, -0.2207 , 0.02177, 0.0832 , -0.04776, -0.48873, 0.05207, -0.15001, -0.19203, 0.06177, 0.15535, -0.05598, 0.11071, 0.39161, -0.17716, 0.05449, 0.25898, -0.13954, 0.4272 , -0.07273, -0.4714 , 0.04993, 0.29526, -0.05319, 0.03451, -0.10583, -0.30137, 0.16372, 0.07541, 0.21018, -0.11459, 0.10976, 0.04923, 0.17688, 0.45658, -0.59762, -0.0039 , 0.08866, 0.53103, 0.153 , -0.1673 , 0.13121, -0.05547, -0.03582, -0.34535, 0.09128, 0.03323, 0.45211, -0.16894, 0.21139, 0.24153, 0.51014, -0.01474, -0.47179, 0.2235 , -0.34668, 0.12126, 0.23727, -0.08424, 0.04555, -0.07698, 0.0428 , -0.13887, 0.29286, -0.28864, 0.53446, 0.02677, -0.04119, 0.40156, 0.38334, 0.01935, 0.02089, 0.02142, -0.11958, -0.44997, 0.13685, -0.12185, -0.00509, 0.60342, 0.65889, -0.16251, 0.46393, 0.19732, 0.19346, -0.07765, 0.17387, 0.07279, 0.04365, -0.01246, 0.4392 , 0.03182, 0.34927, -0.13155, 0.41265, 0.1348 , 0.03162, 0.17821, 0.20899, -0.03224, -0.37799, 0.23646, 0.10512, -0.00483, 0.33617, 0.43214, 0.28264, 0.01725, 0.35155, 0.28504, -0.41468, -0.20859, 0.08935, -0.08568, -0.3982 , -0.61611, 0.574 , -0.34191, 0.03569, 0.08309, 0.02758, 0.30767, -0.14426, -0.23718, 0.19269, 0.12444, 0.20298, -0.08636, -0.30212, 0.06119, 0.08865, 0.60565, 0.23092, -0.16018, -0.44802, -0.14103, 0.08389, 0.08604, 0.17387, -0.11659, 0.15751, -0.25178, 0.12577, 0.28713, -0.00183, 0.05259, -0.0495 , -0.03082, 0.13133, -0.00867, 0.00691, 0.30406, 0.18153, -0.05479, -0.39295, 0.29229, 0.27204, 0.01185, 0.02325, 0.02535, -0.21103, -0.45489, 0.10004, 0.26659, -0.12585, -0.03636, -0.1304 , -0.10385, -0.35109, -0.04138, 0.20202, 0.08724, -0.22088, 0.25375, 
0.08034, 0.0022 , -0.14621, -0.16164, 0.12694, -0.01651, -0.11299, -0.06235, 0.15739, -0.20588, -0.09687, -0.22731, -0.10299, -0.02208, 0.1705 , -0.41714, 0.13382, -0.09988, -0.35683, 0.49678, -0.00604, -0.09917, 0.28355, 0.27951, 0.09213, 0.12555, 0.12955, 0.05188, -0.14202, -0.18416, -0.48024, -0.02423, 0.10908, -0.04117, -0.20895, -0.30235, 0.47612, -0.22305, -0.41871, -0.03084, 0.02981, 0.21836, -0.04544, -0.24222, 0.0735 , -0.16438, -0.05721, 0.31028, 0.26954, 0.20621, 0.04835, 0.10146, -0.2655 , 0.00589, -0.0269 , 0.05519, 0.2096 , -0.21835, 0.12025, -0.44548, 0.05322, -0.23166, 0.03323, 0.13661, -0.39058, 0.1834 , 0.01626, -0.19765, 0.14757, -0.06413, 0.34661, 0.31601, 0.13334, -0.53255, 0.26908, 0.27234, -0.1101 , -0.11572, -0.42586, 0.21509, -0.23383, 0.07461, 0.30356, 0.0955 , -0.30532, -0.2858 , 0.27764, 0.04028, -0.09576], dtype=float32)
np.corrcoef(ft_vecs.get_word_vector('jeremy'), ft_vecs.get_word_vector('Jeremy'))
array([[1. , 0.60866], [0.60866, 1. ]])
np.corrcoef(ft_vecs.get_word_vector('banana'), ft_vecs.get_word_vector('Jeremy'))
array([[1. , 0.14482], [0.14482, 1. ]])
# Vocabulary plus per-word frequencies from the fastText model
ft_words = ft_vecs.get_words(include_freq=True)
# word -> frequency map; dict(zip(...)) replaces the redundant {k:v for k,v in zip(...)}
ft_word_dict = dict(zip(*ft_words))
# Words sorted ascending by the returned frequency value (so the most frequent words are at the end)
ft_words = sorted(ft_word_dict, key=ft_word_dict.get)
len(ft_words)
2519370
from fastai.io import get_data
CLASSES_FN = 'imagenet_class_index.json'
# Download the 1000-class ImageNet index (id -> [synset_id, class_name]) into the tmp dir
get_data(f'http://files.fast.ai/models/{CLASSES_FN}', TMP_PATH/CLASSES_FN)
imagenet_class_index.json: 41.0kB [00:00, 56.2kB/s]
WORDS_FN = 'classids.txt'
# Download the full WordNet synset-id -> name list (~82k entries per the counts below)
get_data(f'http://files.fast.ai/data/{WORDS_FN}', PATH/WORDS_FN)
classids.txt: 1.74MB [00:02, 765kB/s]
# Parse the ImageNet class index; read_text() closes the file immediately
# (the original json.load((...).open()) left the handle open)
class_dict = json.loads((TMP_PATH/CLASSES_FN).read_text())
# {synset_id: class_name} for the 1k competition classes
classids_1k = dict(class_dict.values())
nclass = len(class_dict); nclass
1000
class_dict['0']
['n01440764', 'tench']
# Read all "synset_id name" lines; the context manager ensures the file is closed
# (the original open().readlines() leaked the handle)
with (PATH/WORDS_FN).open() as f:
    classid_lines = f.readlines()
classid_lines[:5]
['n00001740 entity\n', 'n00001930 physical_entity\n', 'n00002137 abstraction\n', 'n00002452 thing\n', 'n00002684 object\n']
# Map each WordNet synset id to its (underscore-joined) name
classids = dict(line.strip().split() for line in classid_lines)
len(classids),len(classids_1k)
(82115, 1000)
# Word vectors for the last 1M entries of the frequency-sorted vocab (i.e. the most frequent
# words, since the sort was ascending), keyed by lowercased word
lc_vec_d = {w.lower(): ft_vecs.get_word_vector(w) for w in ft_words[-1000000:]}
# (synset_id, word_vector) pairs for every WordNet class whose name has a fastText vector
syn_wv = [(k, lc_vec_d[v.lower()]) for k,v in classids.items()
if v.lower() in lc_vec_d]
# Same, restricted to the 1k ImageNet competition classes
syn_wv_1k = [(k, lc_vec_d[v.lower()]) for k,v in classids_1k.items()
if v.lower() in lc_vec_d]
# Lookup table: synset id -> word vector (~49k entries per the output below)
syn2wv = dict(syn_wv)
len(syn2wv)
49469
# Cache the synset->vector mappings; 'with' guarantees the handles are flushed and closed
# (the original pickle.dump(..., path.open('wb')) never closed them explicitly)
with (TMP_PATH/'syn2wv.pkl').open('wb') as f: pickle.dump(syn2wv, f)
with (TMP_PATH/'syn_wv_1k.pkl').open('wb') as f: pickle.dump(syn_wv_1k, f)
# Reload from cache so the notebook can resume from here without recomputing
with (TMP_PATH/'syn2wv.pkl').open('rb') as f: syn2wv = pickle.load(f)
with (TMP_PATH/'syn_wv_1k.pkl').open('rb') as f: syn_wv_1k = pickle.load(f)
def _collect_split(split_dir, images, img_vecs):
    """Append (path relative to PATH, class word vector) for every image under
    split_dir whose synset directory has an entry in syn2wv; returns the number
    of files appended."""
    count = 0
    for d in split_dir.iterdir():
        if d.name not in syn2wv: continue
        vec = syn2wv[d.name]
        for f in d.iterdir():
            images.append(str(f.relative_to(PATH)))
            img_vecs.append(vec)
            count += 1
    return count

images = []
img_vecs = []
# Training images first, then validation images — so validation occupies the tail of `images`
_collect_split(PATH/'train', images, img_vecs)
n_val = _collect_split(PATH/'valid', images, img_vecs)
n_val
28650
# Stack the per-image target vectors into a single (n_images, dim) array
img_vecs = np.stack(img_vecs)
img_vecs.shape
# Cache image list and target vectors; 'with' ensures the handles are closed
# (the original pickle.dump(..., path.open('wb')) never closed them explicitly)
with (TMP_PATH/'images.pkl').open('wb') as f: pickle.dump(images, f)
with (TMP_PATH/'img_vecs.pkl').open('wb') as f: pickle.dump(img_vecs, f)
# Reload from cache so the notebook can resume from here
with (TMP_PATH/'images.pkl').open('rb') as f: images = pickle.load(f)
with (TMP_PATH/'img_vecs.pkl').open('rb') as f: img_vecs = pickle.load(f)
arch = resnet50
n = len(images); n
766876
# Validation images were appended last, so the final n_val indices form the validation set.
# Uses the n_val computed above instead of the original hard-coded 28650 (same value).
val_idxs = list(range(n-n_val, n))
tfms = tfms_from_model(arch, 224, transforms_side_on, max_zoom=1.1)
# Regression-style ImageData: targets are the class word vectors (continuous=True)
md = ImageClassifierData.from_names_and_array(PATH, images, img_vecs, val_idxs=val_idxs,
    classes=None, tfms=tfms, continuous=True, bs=256)
x,y = next(iter(md.val_dl))
# Regression head on the resnet backbone: extra FC layer of 1024, dropout 0.2 on both FC layers;
# md.c output units — presumably the word-vector dimensionality, verify against img_vecs.shape
models = ConvnetBuilder(arch, md.c, is_multi=False, is_reg=True, xtra_fc=[1024], ps=[0.2,0.2])
learn = ConvLearner(md, models, precompute=True)
learn.opt_fn = partial(optim.Adam, betas=(0.9,0.99))
/home/ubuntu/fastai/courses/dl2/fastai/initializers.py:6: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_. if hasattr(m, 'weight'): init_fn(m.weight)
def cos_loss(inp,targ): return 1 - F.cosine_similarity(inp,targ).mean()
# Train against cosine distance to the target word vectors
learn.crit = cos_loss
# LR range test; end_lr=1e15 is extreme — NOTE(review): likely meant a much smaller cap, confirm
learn.lr_find(start_lr=1e-4, end_lr=1e15)
learn.sched.plot()
lr = 1e-2
wd = 1e-7
# Train only the head on precomputed backbone activations: one 20-epoch CLR cycle
learn.precompute=True
learn.fit(lr, 1, cycle_len=20, wds=wd, use_clr=(20,10))
HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))
epoch trn_loss val_loss 0 0.104692 0.125685 1 0.112455 0.129307 2 0.110631 0.126568 3 0.108629 0.127338 4 0.110791 0.125033 5 0.108859 0.125186 6 0.106582 0.123875 7 0.103227 0.123945 8 0.10396 0.12304 9 0.105898 0.124894 10 0.10498 0.122582 11 0.104983 0.122906 12 0.102317 0.121171 13 0.10017 0.121816 14 0.099454 0.119647 15 0.100425 0.120914 16 0.097226 0.119724 17 0.094666 0.118746 18 0.094137 0.118744 19 0.090076 0.117908
[0.11790786389489033]
# Freeze batchnorm layers, then repeat the same 20-epoch cycle
learn.bn_freeze(True)
learn.fit(lr, 1, cycle_len=20, wds=wd, use_clr=(20,10))
HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))
epoch trn_loss val_loss 0 0.104692 0.125685 1 0.112455 0.129307 2 0.110631 0.126568 3 0.108629 0.127338 4 0.110791 0.125033 5 0.108859 0.125186 6 0.106582 0.123875 7 0.103227 0.123945 8 0.10396 0.12304 9 0.105898 0.124894 10 0.10498 0.122582 11 0.104983 0.122906 12 0.102317 0.121171 13 0.10017 0.121816 14 0.099454 0.119647 15 0.100425 0.120914 16 0.097226 0.119724 17 0.094666 0.118746 18 0.094137 0.118744 19 0.090076 0.117908
[0.11790786389489033]
# Per-layer-group learning rates for later fine-tuning — presumably smaller LRs for earlier
# layer groups, per fastai convention; verify
lrs = np.array([lr/1000,lr/100,lr])
learn.precompute=False
learn.freeze_to(1)
learn.save('pre0')
learn.load('pre0')
# Unzip the 1k-class pairs: syns = synset ids, wvs = their word vectors
syns, wvs = list(zip(*syn_wv_1k))
wvs = np.array(wvs)
# Predicted word vectors for the validation set
%time pred_wv = learn.predict()
CPU times: user 18.4 s, sys: 7.91 s, total: 26.3 s Wall time: 7.17 s
# Offset into the validation set used for the qualitative examples below
start=300
# Denormalization function from the validation dataset — presumably reverses the
# input normalization for display; confirm
denorm = md.val_ds.denorm
def show_img(im, figsize=None, ax=None):
    """Draw a single image on `ax` (creating a new figure/axes when none is given),
    hide the axis ticks, and return the axes used."""
    if not ax:
        _fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(im)
    ax.axis('off')
    return ax
def show_imgs(ims, cols, figsize=None):
    """Display the images in a grid with `cols` columns (len(ims)//cols rows)."""
    rows = len(ims) // cols
    fig, axes = plt.subplots(rows, cols, figsize=figsize)
    for ax, im in zip(axes.flat, ims):
        show_img(im, ax=ax)
    plt.tight_layout()
show_imgs(denorm(md.val_ds[start:start+25][0]), 5, (10,10))
import nmslib
def create_index(a):
    """Build an nmslib nearest-neighbour index over the rows of `a` using angular distance."""
    idx = nmslib.init(space='angulardist')
    idx.addDataPointBatch(a)
    idx.createIndex()
    return idx
def get_knns(index, vecs):
    """Batched 10-NN query over `vecs`; returns (ids, distances) unzipped across queries."""
    batches = index.knnQueryBatch(vecs, k=10, num_threads=4)
    return zip(*batches)
def get_knn(index, vec): return index.knnQuery(vec, k=10)
# Index over the 1k class word vectors; find nearest classes for each predicted vector
nn_wvs = create_index(wvs)
idxs,dists = get_knns(nn_wvs, pred_wv)
# Top-3 class names for 10 validation predictions starting at `start`
# (loop variable renamed from `id`, which shadowed the builtin)
[[classids[syns[i]] for i in ids[:3]] for ids in idxs[start:start+10]]
[['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill'], ['spoonbill', 'bustard', 'oystercatcher'], ['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill']]
# All ~49k synsets with vectors, not just the 1k ImageNet classes
# (dropped the redundant list() around zip — unpacking works the same; also renamed the
# loop variable from `id`, which shadowed the builtin)
all_syns, all_wvs = zip(*syn2wv.items())
all_wvs = np.array(all_wvs)
nn_allwvs = create_index(all_wvs)
idxs,dists = get_knns(nn_allwvs, pred_wv)
[[classids[all_syns[i]] for i in ids[:3]] for ids in idxs[start:start+10]]
[['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill'], ['spoonbill', 'bustard', 'oystercatcher'], ['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill'], ['limpkin', 'oystercatcher', 'spoonbill']]
# Index over the predicted image vectors: enables text -> image search
nn_predwv = create_index(pred_wv)
# Pre-extracted English word-vector dict; 'with' closes the handle
# (the original pickle.load(open(...)) leaked it)
with open(TRANS_PATH/'wiki.en.pkl','rb') as fp:
    en_vecd = pickle.load(fp)
# Images whose predicted vectors are nearest to the word 'boat'
vec = en_vecd['boat']
idxs,dists = get_knn(nn_predwv, vec)
show_imgs([open_image(PATH/md.val_ds.fnames[i]) for i in idxs[:3]], 3, figsize=(9,3));
# Vector arithmetic: average of 'engine' and 'boat'
vec = (en_vecd['engine'] + en_vecd['boat'])/2
idxs,dists = get_knn(nn_predwv, vec)
show_imgs([open_image(PATH/md.val_ds.fnames[i]) for i in idxs[:3]], 3, figsize=(9,3));
# Average of 'sail' and 'boat'
vec = (en_vecd['sail'] + en_vecd['boat'])/2
idxs,dists = get_knn(nn_predwv, vec)
show_imgs([open_image(PATH/md.val_ds.fnames[i]) for i in idxs[:3]], 3, figsize=(9,3));
# Image -> image search: predict the word vector for one image, then find validation
# images with the nearest predicted vectors
fname = 'valid/n01440764/ILSVRC2012_val_00007197.JPEG'
img = open_image(PATH/fname)
show_img(img);
# Apply the validation transforms; [None] adds a batch dimension for predict_array
t_img = md.val_ds.transform(img)
pred = learn.predict_array(t_img[None])
idxs,dists = get_knn(nn_predwv, pred)
# Indices 1:4 skip index 0 — presumably the query image is its own nearest neighbour; verify
show_imgs([open_image(PATH/md.val_ds.fnames[i]) for i in idxs[1:4]], 3, figsize=(9,3));