%matplotlib inline
%reload_ext autoreload
%autoreload 2
from fastai.conv_learner import *
from fastai.dataset import *
from pathlib import Path
import json
from PIL import ImageDraw, ImageFont
from matplotlib import patches, patheffects
torch.cuda.set_device(3)
We will be looking at the Pascal VOC dataset. The official download site is quite slow, so you may prefer to download from this mirror. There are two different competition/research datasets, from 2007 and 2012. We'll be using the 2007 version. You can use the larger 2012 version for better results, or even combine them (but be careful to avoid data leakage between the validation sets if you do this).
Unlike previous lessons, we are using the Python 3 standard library module pathlib for our paths and file access. Note that it returns an OS-specific class (on Linux, PosixPath), so your output may look a little different. Most libraries that take paths as input can take a pathlib object, although some (like cv2) can't, in which case you can use str() to convert it to a string.
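A minimal sketch of the pathlib features mentioned above. It uses PurePosixPath so the output is identical on any OS (the notebook itself gets a PosixPath on Linux), and `legacy_api` is a hypothetical stand-in for a library like cv2 that only accepts plain strings.

```python
from pathlib import PurePosixPath

# The / operator joins path components, just like PATH/'tmp' in the notebook.
path = PurePosixPath('data/pascal')
img = path/'VOCdevkit'/'VOC2007'/'JPEGImages'/'000012.jpg'

def legacy_api(fname):
    # Hypothetical stand-in for a library that rejects pathlib objects.
    if not isinstance(fname, str):
        raise TypeError('expected a str path')
    return fname.endswith('.jpg')

print(img.name, img.suffix, img.parent.name)  # 000012.jpg .jpg JPEGImages
print(legacy_api(str(img)))                   # True: str() converts the path
```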
PATH = Path('data/pascal')
list(PATH.iterdir())
[PosixPath('data/pascal/pascal_train2007.json'), PosixPath('data/pascal/pascal_test2007.json'), PosixPath('data/pascal/pascal_val2012.json'), PosixPath('data/pascal/VOCtrainval_06-Nov-2007.tar'), PosixPath('data/pascal/VOCtrainval_11-May-2012.tar'), PosixPath('data/pascal/VOCdevkit'), PosixPath('data/pascal/pascal_val2007.json'), PosixPath('data/pascal/tmp'), PosixPath('data/pascal/models'), PosixPath('data/pascal/src'), PosixPath('data/pascal/pascal_train2012.json')]
As well as the images, there are also annotations: bounding boxes showing where each object is. These were hand labeled. The original version was in XML, which is a little hard to work with nowadays, so we use the more recent JSON version, which you can download from this link.
You can see here how pathlib includes the ability to open files (amongst many other capabilities).
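A small self-contained sketch of the same idea, using a temporary file with made-up COCO-style contents rather than the real Pascal JSON. It also wraps `Path.open()` in a with-block so the file is closed deterministically, which the one-line `json.load(... .open())` style does not guarantee.

```python
import json
import tempfile
from pathlib import Path

# A tiny COCO-style file with made-up contents (not the real Pascal JSON).
toy = {'images': [{'file_name': '000012.jpg', 'id': 12}], 'categories': []}

with tempfile.TemporaryDirectory() as d:
    fn = Path(d)/'toy.json'
    fn.write_text(json.dumps(toy))   # pathlib can write files too
    # Path.open() returns an ordinary file object that json.load accepts;
    # the inner with-block closes it as soon as we're done reading.
    with fn.open() as f:
        data = json.load(f)

print(sorted(data))             # ['categories', 'images']
print(data['images'][0]['id'])  # 12
```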
trn_j = json.load((PATH/'pascal_train2007.json').open())
trn_j.keys()
dict_keys(['images', 'type', 'annotations', 'categories'])
IMAGES,ANNOTATIONS,CATEGORIES = ['images', 'annotations', 'categories']
trn_j[IMAGES][:5]
[{'file_name': '000012.jpg', 'height': 333, 'id': 12, 'width': 500}, {'file_name': '000017.jpg', 'height': 364, 'id': 17, 'width': 480}, {'file_name': '000023.jpg', 'height': 500, 'id': 23, 'width': 334}, {'file_name': '000026.jpg', 'height': 333, 'id': 26, 'width': 500}, {'file_name': '000032.jpg', 'height': 281, 'id': 32, 'width': 500}]
trn_j[ANNOTATIONS][:2]
[{'area': 34104, 'bbox': [155, 96, 196, 174], 'category_id': 7, 'id': 1, 'ignore': 0, 'image_id': 12, 'iscrowd': 0, 'segmentation': [[155, 96, 155, 270, 351, 270, 351, 96]]}, {'area': 13110, 'bbox': [184, 61, 95, 138], 'category_id': 15, 'id': 2, 'ignore': 0, 'image_id': 17, 'iscrowd': 0, 'segmentation': [[184, 61, 184, 199, 279, 199, 279, 61]]}]
trn_j[CATEGORIES][:4]
[{'id': 1, 'name': 'aeroplane', 'supercategory': 'none'}, {'id': 2, 'name': 'bicycle', 'supercategory': 'none'}, {'id': 3, 'name': 'bird', 'supercategory': 'none'}, {'id': 4, 'name': 'boat', 'supercategory': 'none'}]
It's helpful to use constants instead of strings, since we get tab-completion and don't mistype.
FILE_NAME,ID,IMG_ID,CAT_ID,BBOX = 'file_name','id','image_id','category_id','bbox'
cats = {o[ID]:o['name'] for o in trn_j[CATEGORIES]}
trn_fns = {o[ID]:o[FILE_NAME] for o in trn_j[IMAGES]}
trn_ids = [o[ID] for o in trn_j[IMAGES]]
list((PATH/'VOCdevkit'/'VOC2007').iterdir())
[PosixPath('data/pascal/VOCdevkit/VOC2007/SegmentationClass'), PosixPath('data/pascal/VOCdevkit/VOC2007/Annotations'), PosixPath('data/pascal/VOCdevkit/VOC2007/SegmentationObject'), PosixPath('data/pascal/VOCdevkit/VOC2007/JPEGImages'), PosixPath('data/pascal/VOCdevkit/VOC2007/ImageSets')]
JPEGS = 'VOCdevkit/VOC2007/JPEGImages'
IMG_PATH = PATH/JPEGS
list(IMG_PATH.iterdir())[:5]
[PosixPath('data/pascal/VOCdevkit/VOC2007/JPEGImages/005475.jpg'), PosixPath('data/pascal/VOCdevkit/VOC2007/JPEGImages/001898.jpg'), PosixPath('data/pascal/VOCdevkit/VOC2007/JPEGImages/006004.jpg'), PosixPath('data/pascal/VOCdevkit/VOC2007/JPEGImages/006660.jpg'), PosixPath('data/pascal/VOCdevkit/VOC2007/JPEGImages/005067.jpg')]
Each image has a unique ID.
im0_d = trn_j[IMAGES][0]
im0_d[FILE_NAME],im0_d[ID]
('000012.jpg', 12)
A defaultdict is useful any time you want a default dictionary entry for new keys. Here we create a dict from image IDs to a list of annotations (tuples of bounding box and class ID).
The JSON gives each box as [x, y, width, height]. We convert this into [top, left, bottom, right] (the top-left and bottom-right corners), putting rows (y) before columns (x) to be consistent with numpy.
def hw_bb(bb): return np.array([bb[1], bb[0], bb[3]+bb[1]-1, bb[2]+bb[0]-1])
trn_anno = collections.defaultdict(lambda:[])
for o in trn_j[ANNOTATIONS]:
    if not o['ignore']:
        bb = o[BBOX]
        bb = hw_bb(bb)
        trn_anno[o[IMG_ID]].append((bb,o[CAT_ID]))
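To see the grouping and the coordinate conversion on their own, here is a toy version of the loop above. The first two annotations are copied from `trn_j[ANNOTATIONS][:2]`; the third is a made-up ignored box. It uses `defaultdict(list)`, which is equivalent to `defaultdict(lambda:[])`. The `-1` terms make bottom/right inclusive pixel indices: a box starting at column 155 with width 196 covers columns 155 to 350.

```python
import collections
import numpy as np

def hw_bb(bb):
    # [x, y, width, height] -> [top, left, bottom, right] (bottom/right inclusive)
    return np.array([bb[1], bb[0], bb[3]+bb[1]-1, bb[2]+bb[0]-1])

annotations = [
    {'bbox': [155, 96, 196, 174], 'category_id': 7,  'image_id': 12, 'ignore': 0},
    {'bbox': [184, 61, 95, 138],  'category_id': 15, 'image_id': 17, 'ignore': 0},
    {'bbox': [0, 0, 10, 10],      'category_id': 1,  'image_id': 17, 'ignore': 1},  # made up
]

# defaultdict(list) creates an empty list the first time an image ID is seen.
anno = collections.defaultdict(list)
for o in annotations:
    if not o['ignore']:
        anno[o['image_id']].append((hw_bb(o['bbox']), o['category_id']))

print(anno[12][0][0].tolist(), anno[12][0][1])  # [96, 155, 269, 350] 7
print(len(anno[17]))                            # 1: the ignored box was skipped
```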
len(trn_anno)
2501
im_a = trn_anno[im0_d[ID]]; im_a
[(array([ 96, 155, 269, 350]), 7)]
im0_a = im_a[0]; im0_a
(array([ 96, 155, 269, 350]), 7)
cats[7]
'car'
trn_anno[17]
[(array([ 61, 184, 198, 278]), 15), (array([ 77, 89, 335, 402]), 13)]
cats[15],cats[13]
('person', 'horse')
Some libs take VOC format bounding boxes, so this lets us convert back when required:
bb_voc = [155, 96, 196, 174]
bb_fastai = hw_bb(bb_voc)
def bb_hw(a): return np.array([a[1],a[0],a[3]-a[1]+1,a[2]-a[0]+1])
f'expected: {bb_voc}, actual: {bb_hw(bb_fastai)}'
'expected: [155, 96, 196, 174], actual: [155 96 196 174]'
You can use Visual Studio Code (vscode, an open source editor that comes with recent versions of Anaconda or can be installed separately), or most other editors and IDEs, to find out all about the open_image function. vscode things to know:
im = open_image(IMG_PATH/im0_d[FILE_NAME])
Matplotlib's plt.subplots is a really useful wrapper for creating plots, regardless of whether you have more than one subplot. Note that Matplotlib has an optional object-oriented API which I think is much easier to understand and use (although few examples online use it!)
def show_img(im, figsize=None, ax=None):
    if not ax: fig,ax = plt.subplots(figsize=figsize)
    ax.imshow(im)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    return ax
A simple but rarely used trick for making text visible regardless of background is to use white text with a black outline, or vice versa. Here's how to do it in matplotlib.
def draw_outline(o, lw):
    o.set_path_effects([patheffects.Stroke(
        linewidth=lw, foreground='black'), patheffects.Normal()])
Note that * in argument lists is the splat (argument unpacking) operator. In this case it's a little shortcut compared to writing out b[-2],b[-1].
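A tiny illustration of that shortcut, with a hypothetical `rect` function standing in for `patches.Rectangle(xy, width, height)`: in a call, `*b[-2:]` expands the two-element slice into two separate positional arguments.

```python
def rect(xy, width, height):
    # Hypothetical stand-in for patches.Rectangle(xy, width, height, ...).
    return (xy, width, height)

b = [155, 96, 196, 174]   # [x, y, width, height], as returned by bb_hw

explicit = rect(b[:2], b[-2], b[-1])
unpacked = rect(b[:2], *b[-2:])   # *b[-2:] expands to b[-2], b[-1]

print(explicit == unpacked)  # True
```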
def draw_rect(ax, b):
    patch = ax.add_patch(patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor='white', lw=2))
    draw_outline(patch, 4)
def draw_text(ax, xy, txt, sz=14):
    text = ax.text(*xy, txt,
        verticalalignment='top', color='white', fontsize=sz, weight='bold')
    draw_outline(text, 1)
ax = show_img(im)
b = bb_hw(im0_a[0])
draw_rect(ax, b)
draw_text(ax, b[:2], cats[im0_a[1]])
def draw_im(im, ann):
    ax = show_img(im, figsize=(16,8))
    for b,c in ann:
        b = bb_hw(b)
        draw_rect(ax, b)
        draw_text(ax, b[:2], cats[c], sz=16)
def draw_idx(i):
    im_a = trn_anno[i]
    im = open_image(IMG_PATH/trn_fns[i])
    print(im.shape)
    draw_im(im, im_a)
draw_idx(17)
(364, 480, 3)
A lambda function is simply a way to define an anonymous function inline. Here we use it to describe how to sort the annotations for each image: by bounding box size (descending).
def get_lrg(b):
    if not b: raise Exception()
    # np.prod rather than np.product, which was removed in NumPy 2.0
    b = sorted(b, key=lambda x: np.prod(x[0][-2:]-x[0][:2]), reverse=True)
    return b[0]
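Here is the same sort key applied to the two annotations of image 17 shown earlier (person and horse), pulled out into a named function instead of a lambda. Since only the first element is used, `max` with the same key gives the same answer without sorting the whole list; this is an alternative, not what the notebook does.

```python
import numpy as np

# Image 17's annotations in the notebook's (bbox, class_id) format;
# boxes are [top, left, bottom, right].
anns = [(np.array([61, 184, 198, 278]), 15),
        (np.array([77, 89, 335, 402]), 13)]

def box_area(ann):
    # (bottom - top) * (right - left), as in get_lrg's lambda
    return np.prod(ann[0][-2:] - ann[0][:2])

largest_sorted = sorted(anns, key=box_area, reverse=True)[0]
largest_max = max(anns, key=box_area)   # same result, no full sort

print(largest_sorted[1], largest_max[1])  # 13 13 (the horse)
```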
trn_lrg_anno = {a: get_lrg(b) for a,b in trn_anno.items()}
Now we have a dictionary from image id to a single bounding box - the largest for that image.
b,c = trn_lrg_anno[23]
b = bb_hw(b)
ax = show_img(open_image(IMG_PATH/trn_fns[23]), figsize=(5,10))
draw_rect(ax, b)
draw_text(ax, b[:2], cats[c], sz=16)
(PATH/'tmp').mkdir(exist_ok=True)
CSV = PATH/'tmp/lrg.csv'
Often it's easiest to simply create a CSV of the data you want to model, rather than trying to create a custom dataset. Here we use Pandas to help us create a CSV of the image filename and class.
df = pd.DataFrame({'fn': [trn_fns[o] for o in trn_ids],
    'cat': [cats[trn_lrg_anno[o][1]] for o in trn_ids]}, columns=['fn','cat'])
df.to_csv(CSV, index=False)
f_model = resnet34
sz=224
bs=64
From here it's just like Dogs vs Cats!
tfms = tfms_from_model(f_model, sz, aug_tfms=transforms_side_on, crop_type=CropType.NO)
md = ImageClassifierData.from_csv(PATH, JPEGS, CSV, tfms=tfms, bs=bs)
x,y=next(iter(md.val_dl))
show_img(md.val_ds.denorm(to_np(x))[0]);
learn = ConvLearner.pretrained(f_model, md, metrics=[accuracy])
learn.opt_fn = optim.Adam
lrf=learn.lr_find(1e-5,100)
When your LR finder graph looks like this, you can ask for more points on each end:
learn.sched.plot()
learn.sched.plot(n_skip=5, n_skip_end=1)
lr = 2e-2
learn.fit(lr, 1, cycle_len=1)
epoch  trn_loss  val_loss  accuracy
0      1.335532  0.6443    0.804838
[0.6443001, 0.80483774095773697]
lrs = np.array([lr/1000,lr/100,lr])
learn.freeze_to(-2)
lrf=learn.lr_find(lrs/1000)
learn.sched.plot(1)
learn.fit(lrs/5, 1, cycle_len=1)
epoch  trn_loss  val_loss  accuracy
0      0.780925  0.575539  0.821064
[0.57553864, 0.82106370478868484]
learn.unfreeze()
Accuracy isn't improving much: since many images contain several different objects, it's impossible for a model that predicts a single class per image to be that accurate.
learn.fit(lrs/5, 1, cycle_len=2)
epoch  trn_loss  val_loss  accuracy
0      0.609306  0.570568  0.821514
1      0.462856  0.574303  0.8128
[0.57430345, 0.81280048191547394]
learn.save('clas_one')
learn.load('clas_one')
x,y = next(iter(md.val_dl))
probs = F.softmax(predict_batch(learn.model, x), -1)
x,preds = to_np(x),to_np(probs)
preds = np.argmax(preds, -1)
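A numpy sketch of what the three lines above compute, on made-up logits for 2 images and 3 classes. Because softmax is monotonic within each row, taking the argmax of the probabilities picks the same class as taking the argmax of the raw outputs; the softmax is only needed if you want the probabilities themselves.

```python
import numpy as np

def softmax(logits, axis=-1):
    # Subtracting the row max improves numerical stability without changing the result.
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])   # made-up model outputs
probs = softmax(logits)

print(probs.sum(axis=-1))       # each row sums to (approximately) 1
print(np.argmax(probs, -1))     # [0 2]
print(np.argmax(logits, -1))    # [0 2]: same classes as from the probabilities
```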
You can use the Python debugger pdb to step through code. Use pdb.set_trace() to set a breakpoint, and the %debug magic to trace an error.
Commands you need to know:
fig, axes = plt.subplots(3, 4, figsize=(12, 8))
for i,ax in enumerate(axes.flat):
    ima=md.val_ds.denorm(x)[i]
    b = md.classes[preds[i]]
    ax = show_img(ima, ax=ax)
    draw_text(ax, (0,0), b)
plt.tight_layout()
It's doing a pretty good job of classifying the largest object!
Now we'll try to find the bounding box of the largest object. This is simply a regression with 4 outputs. So we can use a CSV with multiple 'labels'.
BB_CSV = PATH/'tmp/bb.csv'
bb = np.array([trn_lrg_anno[o][0] for o in trn_ids])
bbs = [' '.join(str(p) for p in o) for o in bb]
df = pd.DataFrame({'fn': [trn_fns[o] for o in trn_ids], 'bbox': bbs}, columns=['fn','bbox'])
df.to_csv(BB_CSV, index=False)
BB_CSV.open().readlines()[:5]
['fn,bbox\n', '000012.jpg,96 155 269 350\n', '000017.jpg,77 89 335 402\n', '000023.jpg,1 2 461 242\n', '000026.jpg,124 89 211 336\n']
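To make the "multiple labels in one column" format concrete, here is one way the first rows of bb.csv (copied from the output above) can be decoded back into a 4-value regression target per image. This uses the standard csv module as an illustration; it is not fastai's own CSV reader.

```python
import csv
import io
import numpy as np

# First lines of bb.csv as shown above; the bbox column holds 4 space-separated numbers.
text = ('fn,bbox\n'
        '000012.jpg,96 155 269 350\n'
        '000017.jpg,77 89 335 402\n')

rows = list(csv.DictReader(io.StringIO(text)))
fns = [r['fn'] for r in rows]
boxes = np.array([[float(v) for v in r['bbox'].split()] for r in rows])

print(fns)          # ['000012.jpg', '000017.jpg']
print(boxes.shape)  # (2, 4): one [top, left, bottom, right] target per image
```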
f_model=resnet34
sz=224
bs=64
Set continuous=True to tell fastai this is a regression problem, which means it won't one-hot encode the labels, and will use MSE as the default criterion (loss function).
Note that we have to tell the transforms constructor that our labels are coordinates, so that it can handle the transforms correctly.
Also, we use CropType.NO because we want to 'squish' the rectangular images into squares, rather than center cropping, so that we don't accidentally crop out some of the objects. (This is less of an issue in something like ImageNet, where there is a single object to classify, and it's generally large and centrally located.)
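A simplified model of the two options, for illustration only: fastai's exact resizing and cropping arithmetic may round differently. Squishing rescales rows by sz/height and columns by sz/width, so every box survives. A center crop to a min(h, w) square can cut boxes near the edges, as it would for the large person box in image 000023.jpg (500 high by 334 wide, box 1 2 461 242 from bb.csv above).

```python
import numpy as np

def squish_box(box, img_h, img_w, sz=224):
    # Rescale [top, left, bottom, right] when an img_h x img_w image is resized
    # to sz x sz without cropping: rows scale by sz/img_h, columns by sz/img_w.
    scale = np.array([sz/img_h, sz/img_w, sz/img_h, sz/img_w])
    return np.asarray(box, dtype=float) * scale

def center_crop_cuts_box(box, img_h, img_w):
    # A center crop to a square of side min(h, w) keeps only this central region.
    side = min(img_h, img_w)
    top0, left0 = (img_h - side) // 2, (img_w - side) // 2
    t, l, b, r = box
    return t < top0 or l < left0 or b >= top0 + side or r >= left0 + side

# 000012.jpg is 333 high x 500 wide; its largest box is [96, 155, 269, 350].
print(squish_box([96, 155, 269, 350], 333, 500).round(1))  # approx. 64.6, 69.4, 181.0, 156.8
print(center_crop_cuts_box([1, 2, 461, 242], 500, 334))    # True: the crop would cut this box
```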
augs = [RandomFlip(),
        RandomRotate(30),
        RandomLighting(0.1,0.1)]
tfms = tfms_from_model(f_model, sz, crop_type=CropType.NO, aug_tfms=augs)
md = ImageClassifierData.from_csv(PATH, JPEGS, BB_CSV, tfms=tfms, continuous=True, bs=4)
idx=3
fig,axes = plt.subplots(3,3, figsize=(9,9))
for i,ax in enumerate(axes.flat):
    x,y=next(iter(md.aug_dl))
    ima=md.val_ds.denorm(to_np(x))[idx]
    b = bb_hw(to_np(y[idx]))
    print(b)
    show_img(ima, ax=ax)
    draw_rect(ax, b)
[ 115. 63. 240. 311.] [ 115. 63. 240. 311.] [ 115. 63. 240. 311.] [ 115. 63. 240. 311.] [ 115. 63. 240. 311.] [ 115. 63. 240. 311.] [ 115. 63. 240. 311.] [ 115. 63. 240. 311.] [ 115. 63. 240. 311.]
augs = [RandomFlip(tfm_y=TfmType.COORD),
        RandomRotate(30, tfm_y=TfmType.COORD),
        RandomLighting(0.1,0.1, tfm_y=TfmType.COORD)]
tfms = tfms_from_model(f_model, sz, crop_type=CropType.NO, tfm_y=TfmType.COORD, aug_tfms=augs)
md = ImageClassifierData.from_csv(PATH, JPEGS, BB_CSV, tfms=tfms, continuous=True, bs=4)
idx=3
fig,axes = plt.subplots(3,3, figsize=(9,9))
for i,ax in enumerate(axes.flat):
    x,y=next(iter(md.aug_dl))
    ima=md.val_ds.denorm(to_np(x))[idx]
    b = bb_hw(to_np(y[idx]))
    print(b)
    show_img(ima, ax=ax)
    draw_rect(ax, b)
[ 48. 34. 112. 188.] [ 65. 36. 107. 185.] [ 49. 27. 131. 195.] [ 24. 18. 147. 204.] [ 61. 34. 113. 188.] [ 55. 31. 121. 191.] [ 52. 19. 144. 203.] [ 7. 0. 193. 222.] [ 52. 38. 105. 182.]
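The boxes now change from sample to sample because each augmentation is applied to y as well as x. For a horizontal flip the coordinate update is easy to write by hand; the sketch below is a hand-rolled illustration on a toy array, not fastai's implementation. With inclusive pixel coordinates, column c moves to img_w - 1 - c, so left and right swap roles.

```python
import numpy as np

def hflip_image(img):
    # Flip an H x W (x C) array left-to-right.
    return img[:, ::-1]

def hflip_box(box, img_w):
    # [top, left, bottom, right], inclusive: column c maps to img_w - 1 - c.
    t, l, b, r = box
    return np.array([t, img_w - 1 - r, b, img_w - 1 - l])

# Toy 4x6 image with an "object" of ones in rows 1-2, columns 0-1.
img = np.zeros((4, 6), dtype=int)
img[1:3, 0:2] = 1
box = np.array([1, 0, 2, 1])

flipped = hflip_image(img)
new_box = hflip_box(box, img_w=6)

# Recover the box directly from the flipped pixels to check the formula.
rows, cols = np.nonzero(flipped)
recovered = [int(rows.min()), int(cols.min()), int(rows.max()), int(cols.max())]
print(new_box.tolist(), recovered)  # [1, 4, 2, 5] [1, 4, 2, 5]
```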
tfm_y = TfmType.COORD
augs = [RandomFlip(tfm_y=tfm_y),
        RandomRotate(3, p=0.5, tfm_y=tfm_y),
        RandomLighting(0.05,0.05, tfm_y=tfm_y)]
tfms = tfms_from_model(f_model, sz, crop_type=CropType.NO, tfm_y=tfm_y, aug_tfms=augs)
md = ImageClassifierData.from_csv(PATH, JPEGS, BB_CSV, tfms=tfms, bs=bs, continuous=True)
fastai lets you use a custom_head to add your own module on top of a convnet, instead of the adaptive pooling and fully connected net which is added by default. In this case, we don't want to do any pooling, since we need to know the activations of each grid cell.
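The final layer has 4 activations, one per bounding box coordinate. Our target is continuous, not categorical, so the loss is computed directly on the module outputs, without any sigmoid or softmax. MSE is the default for continuous=True, but here we replace it with L1 loss (learn.crit = nn.L1Loss()).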
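A numpy sketch of the two losses on one made-up prediction for the box of image 000012.jpg, mirroring what nn.MSELoss and nn.L1Loss compute with their default reduction='mean'. MSE squares each coordinate error, so the single badly placed edge (off by 31 pixels) dominates it; L1 grows linearly. That is a general property of the two losses, not a statement of why this notebook chose L1.

```python
import numpy as np

target = np.array([96., 155., 269., 350.])   # [top, left, bottom, right] in pixels
pred   = np.array([90., 160., 300., 345.])   # made-up raw outputs, no sigmoid/softmax

err = pred - target            # [-6, 5, 31, -5]
mse = np.mean(err ** 2)        # like nn.MSELoss() with reduction='mean'
l1  = np.mean(np.abs(err))     # like nn.L1Loss() with reduction='mean'

print(mse, l1)  # 261.75 11.75
```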
512*7*7
25088
head_reg4 = nn.Sequential(Flatten(), nn.Linear(25088,4))
learn = ConvLearner.pretrained(f_model, md, custom_head=head_reg4)
learn.opt_fn = optim.Adam
learn.crit = nn.L1Loss()
learn.summary()
OrderedDict([('Conv2d-1', OrderedDict([('input_shape', [-1, 3, 224, 224]), ('output_shape', [-1, 64, 112, 112]), ('trainable', False), ('nb_params', 9408)])), ('BatchNorm2d-2', OrderedDict([('input_shape', [-1, 64, 112, 112]), ('output_shape', [-1, 64, 112, 112]), ('trainable', False), ('nb_params', 128)])), ('ReLU-3', OrderedDict([('input_shape', [-1, 64, 112, 112]), ('output_shape', [-1, 64, 112, 112]), ('nb_params', 0)])), ('MaxPool2d-4', OrderedDict([('input_shape', [-1, 64, 112, 112]), ('output_shape', [-1, 64, 56, 56]), ('nb_params', 0)])), ('Conv2d-5', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('trainable', False), ('nb_params', 36864)])), ('BatchNorm2d-6', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('trainable', False), ('nb_params', 128)])), ('ReLU-7', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('nb_params', 0)])), ('Conv2d-8', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('trainable', False), ('nb_params', 36864)])), ('BatchNorm2d-9', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('trainable', False), ('nb_params', 128)])), ('ReLU-10', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('nb_params', 0)])), ('BasicBlock-11', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('nb_params', 0)])), ('Conv2d-12', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('trainable', False), ('nb_params', 36864)])), ('BatchNorm2d-13', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('trainable', False), ('nb_params', 128)])), ('ReLU-14', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('nb_params', 0)])), ('Conv2d-15', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), 
('trainable', False), ('nb_params', 36864)])), ('BatchNorm2d-16', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('trainable', False), ('nb_params', 128)])), ('ReLU-17', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('nb_params', 0)])), ('BasicBlock-18', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('nb_params', 0)])), ('Conv2d-19', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('trainable', False), ('nb_params', 36864)])), ('BatchNorm2d-20', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('trainable', False), ('nb_params', 128)])), ('ReLU-21', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('nb_params', 0)])), ('Conv2d-22', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('trainable', False), ('nb_params', 36864)])), ('BatchNorm2d-23', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('trainable', False), ('nb_params', 128)])), ('ReLU-24', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('nb_params', 0)])), ('BasicBlock-25', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 64, 56, 56]), ('nb_params', 0)])), ('Conv2d-26', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 73728)])), ('BatchNorm2d-27', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 256)])), ('ReLU-28', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('nb_params', 0)])), ('Conv2d-29', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 147456)])), ('BatchNorm2d-30', OrderedDict([('input_shape', [-1, 128, 
28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 256)])), ('Conv2d-31', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 8192)])), ('BatchNorm2d-32', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 256)])), ('ReLU-33', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('nb_params', 0)])), ('BasicBlock-34', OrderedDict([('input_shape', [-1, 64, 56, 56]), ('output_shape', [-1, 128, 28, 28]), ('nb_params', 0)])), ('Conv2d-35', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 147456)])), ('BatchNorm2d-36', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 256)])), ('ReLU-37', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('nb_params', 0)])), ('Conv2d-38', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 147456)])), ('BatchNorm2d-39', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 256)])), ('ReLU-40', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('nb_params', 0)])), ('BasicBlock-41', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('nb_params', 0)])), ('Conv2d-42', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 147456)])), ('BatchNorm2d-43', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 256)])), ('ReLU-44', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('nb_params', 
0)])), ('Conv2d-45', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 147456)])), ('BatchNorm2d-46', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 256)])), ('ReLU-47', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('nb_params', 0)])), ('BasicBlock-48', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('nb_params', 0)])), ('Conv2d-49', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 147456)])), ('BatchNorm2d-50', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 256)])), ('ReLU-51', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('nb_params', 0)])), ('Conv2d-52', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 147456)])), ('BatchNorm2d-53', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('trainable', False), ('nb_params', 256)])), ('ReLU-54', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('nb_params', 0)])), ('BasicBlock-55', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 128, 28, 28]), ('nb_params', 0)])), ('Conv2d-56', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 294912)])), ('BatchNorm2d-57', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 512)])), ('ReLU-58', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('Conv2d-59', OrderedDict([('input_shape', [-1, 256, 14, 14]), 
('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 589824)])), ('BatchNorm2d-60', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 512)])), ('Conv2d-61', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 32768)])), ('BatchNorm2d-62', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 512)])), ('ReLU-63', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('BasicBlock-64', OrderedDict([('input_shape', [-1, 128, 28, 28]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('Conv2d-65', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 589824)])), ('BatchNorm2d-66', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 512)])), ('ReLU-67', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('Conv2d-68', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 589824)])), ('BatchNorm2d-69', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 512)])), ('ReLU-70', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('BasicBlock-71', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('Conv2d-72', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 589824)])), ('BatchNorm2d-73', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), 
('trainable', False), ('nb_params', 512)])), ('ReLU-74', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('Conv2d-75', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 589824)])), ('BatchNorm2d-76', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 512)])), ('ReLU-77', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('BasicBlock-78', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('Conv2d-79', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 589824)])), ('BatchNorm2d-80', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 512)])), ('ReLU-81', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('Conv2d-82', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 589824)])), ('BatchNorm2d-83', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 512)])), ('ReLU-84', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('BasicBlock-85', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('Conv2d-86', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 589824)])), ('BatchNorm2d-87', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 512)])), ('ReLU-88', 
OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('Conv2d-89', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 589824)])), ('BatchNorm2d-90', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 512)])), ('ReLU-91', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('BasicBlock-92', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('Conv2d-93', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 589824)])), ('BatchNorm2d-94', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 512)])), ('ReLU-95', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('Conv2d-96', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 589824)])), ('BatchNorm2d-97', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('trainable', False), ('nb_params', 512)])), ('ReLU-98', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('BasicBlock-99', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 256, 14, 14]), ('nb_params', 0)])), ('Conv2d-100', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 1179648)])), ('BatchNorm2d-101', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 1024)])), ('ReLU-102', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), 
('nb_params', 0)])), ('Conv2d-103', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 2359296)])), ('BatchNorm2d-104', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 1024)])), ('Conv2d-105', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 131072)])), ('BatchNorm2d-106', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 1024)])), ('ReLU-107', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('nb_params', 0)])), ('BasicBlock-108', OrderedDict([('input_shape', [-1, 256, 14, 14]), ('output_shape', [-1, 512, 7, 7]), ('nb_params', 0)])), ('Conv2d-109', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 2359296)])), ('BatchNorm2d-110', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 1024)])), ('ReLU-111', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('nb_params', 0)])), ('Conv2d-112', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 2359296)])), ('BatchNorm2d-113', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 1024)])), ('ReLU-114', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('nb_params', 0)])), ('BasicBlock-115', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('nb_params', 0)])), ('Conv2d-116', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 2359296)])), ('BatchNorm2d-117', OrderedDict([('input_shape', [-1, 512, 7, 
7]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 1024)])), ('ReLU-118', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('nb_params', 0)])), ('Conv2d-119', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 2359296)])), ('BatchNorm2d-120', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('trainable', False), ('nb_params', 1024)])), ('ReLU-121', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('nb_params', 0)])), ('BasicBlock-122', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 512, 7, 7]), ('nb_params', 0)])), ('Flatten-123', OrderedDict([('input_shape', [-1, 512, 7, 7]), ('output_shape', [-1, 25088]), ('nb_params', 0)])), ('Linear-124', OrderedDict([('input_shape', [-1, 25088]), ('output_shape', [-1, 4]), ('trainable', True), ('nb_params', 100356)]))])
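As a sanity check on the last two rows of the summary: the shapes above show resnet34 halving the resolution five times (224 to 112, 56, 28, 14, 7), an overall factor of 32, leaving a 7x7 grid of 512 channels. Flattening gives the 25088 inputs of the custom head, and the Linear layer's parameter count is weights plus biases.

```python
sz, stride, channels, n_out = 224, 32, 512, 4

grid = sz // stride                  # 7, matching the [-1, 512, 7, 7] shapes
n_flat = channels * grid * grid      # 25088, the Flatten-123 output size
n_params = n_flat * n_out + n_out    # weights + biases of nn.Linear(25088, 4)

print(grid, n_flat, n_params)  # 7 25088 100356, as in the Linear-124 row
```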
learn.lr_find(1e-5,100)
learn.sched.plot(5)
lr = 2e-3
learn.fit(lr, 2, cycle_len=1, cycle_mult=2)
epoch  trn_loss   val_loss
0      49.523444  34.764141
1      36.864003  28.007317
2      30.925234  27.230705
[27.230705]
lrs = np.array([lr/100,lr/10,lr])
learn.freeze_to(-2)
lrf=learn.lr_find(lrs/1000)
learn.sched.plot(1)
epoch  trn_loss    val_loss
0      102.406115  91141120000.0
learn.fit(lrs, 2, cycle_len=1, cycle_mult=2)
epoch  trn_loss   val_loss
0      25.616161  22.83597
1      21.812624  21.387115
2      17.867176  20.335539
[20.335539]
learn.freeze_to(-3)
learn.fit(lrs, 1, cycle_len=2)
epoch  trn_loss   val_loss
0      16.571885  20.948696
1      15.072718  19.925312
[19.925312]
learn.save('reg4')
learn.load('reg4')
x,y = next(iter(md.val_dl))
learn.model.eval()
preds = to_np(learn.model(VV(x)))
fig, axes = plt.subplots(3, 4, figsize=(12, 8))
for i,ax in enumerate(axes.flat):
    ima=md.val_ds.denorm(to_np(x))[i]
    b = bb_hw(preds[i])
    ax = show_img(ima, ax=ax)
    draw_rect(ax, b)
plt.tight_layout()
f_model=resnet34
sz=224
bs=64
val_idxs = get_cv_idxs(len(trn_fns))
tfms = tfms_from_model(f_model, sz, crop_type=CropType.NO, tfm_y=TfmType.COORD, aug_tfms=augs)
md = ImageClassifierData.from_csv(PATH, JPEGS, BB_CSV, tfms=tfms,
                                  bs=bs, continuous=True, val_idxs=val_idxs)
md2 = ImageClassifierData.from_csv(PATH, JPEGS, CSV, tfms=tfms_from_model(f_model, sz))
A dataset can be anything with __len__ and __getitem__. Here's a dataset that adds a second label to an existing dataset:
class ConcatLblDataset(Dataset):
    def __init__(self, ds, y2): self.ds,self.y2 = ds,y2
    def __len__(self): return len(self.ds)

    def __getitem__(self, i):
        x,y = self.ds[i]
        return (x, (y,self.y2[i]))
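Because the only requirement is __len__ and __getitem__, the same wrapper can be exercised without fastai or PyTorch. In this sketch the class drops the torch Dataset base class, and the toy bbox_ds and class_ids are hypothetical stand-ins (a plain Python list already satisfies the protocol); it shows that each item comes back as (x, (bbox, class_id)), the same shape as val_ds2[0][1] in this section.

```python
class ConcatLblDataset:
    """Wraps a dataset of (x, y) pairs and attaches a second label y2[i] to each item."""
    def __init__(self, ds, y2): self.ds, self.y2 = ds, y2
    def __len__(self): return len(self.ds)

    def __getitem__(self, i):
        x, y = self.ds[i]
        return (x, (y, self.y2[i]))

# Hypothetical toy data: (image_id, bbox) pairs plus a parallel list of class ids.
bbox_ds = [("img0", [0, 49, 205, 180]), ("img1", [10, 20, 30, 40])]
class_ids = [14, 7]

ds2 = ConcatLblDataset(bbox_ds, class_ids)
print(len(ds2))     # 2
print(ds2[0][1])    # ([0, 49, 205, 180], 14)
```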
We'll use it to add the classes to the bounding box labels.
trn_ds2 = ConcatLblDataset(md.trn_ds, md2.trn_y)
val_ds2 = ConcatLblDataset(md.val_ds, md2.val_y)
val_ds2[0][1]
(array([ 0., 49., 205., 180.], dtype=float32), 14)
We can replace the dataloaders' datasets with these new ones.
md.trn_dl.dataset = trn_ds2
md.val_dl.dataset = val_ds2
We have to denormalize the images from the dataloader before they can be plotted.
x,y=next(iter(md.val_dl))
idx=3
ima=md.val_ds.ds.denorm(to_np(x))[idx]
b = bb_hw(to_np(y[0][idx])); b
array([ 52., 38., 106., 184.], dtype=float32)
ax = show_img(ima)
draw_rect(ax, b)
draw_text(ax, b[:2], md2.classes[y[1][idx]])
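Roughly speaking, denormalizing undoes the per-channel normalization applied by the transforms and moves the channel axis last so matplotlib can show the image. The NumPy sketch below assumes ImageNet channel statistics and an (N, C, H, W) batch layout; normalize and denorm here are illustrative functions, not fastai's implementation, which may differ in detail.

```python
import numpy as np

# Assumed ImageNet per-channel statistics (RGB order).
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize(imgs_nhwc):
    """(N, H, W, C) images in [0, 1] -> normalized (N, C, H, W) batch."""
    return np.transpose((imgs_nhwc - MEAN) / STD, (0, 3, 1, 2))

def denorm(batch_nchw):
    """Inverse of normalize: (N, C, H, W) -> (N, H, W, C) back in the [0, 1] range."""
    return np.transpose(batch_nchw, (0, 2, 3, 1)) * STD + MEAN

rng = np.random.default_rng(0)
imgs = rng.random((2, 4, 4, 3), dtype=np.float32)
batch = normalize(imgs)
restored = denorm(batch)
print(batch.shape)                              # (2, 3, 4, 4)
print(restored.shape)                           # (2, 4, 4, 3)
print(np.allclose(restored, imgs, atol=1e-5))   # True
```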
We need one output activation for each class (for its probability) plus one for each bounding box coordinate. We'll use an extra linear layer this time, plus some dropout, to help us train a more flexible model.
head_reg4 = nn.Sequential(
    Flatten(),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(25088,256),
    nn.ReLU(),
    nn.BatchNorm1d(256),
    nn.Dropout(0.5),
    nn.Linear(256,4+len(cats)),
)
models = ConvnetBuilder(f_model, 0, 0, 0, custom_head=head_reg4)
learn = ConvLearner(md, models)
learn.opt_fn = optim.Adam
def detn_loss(input, target):
    bb_t,c_t = target
    bb_i,c_i = input[:, :4], input[:, 4:]
    bb_i = F.sigmoid(bb_i)*224
    # I looked at these quantities separately first then picked a multiplier
    # to make them approximately equal
    return F.l1_loss(bb_i, bb_t) + F.cross_entropy(c_i, c_t)*20
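To make the combined loss concrete, here is a NumPy restatement of detn_loss: the first 4 activations are squashed with a sigmoid and scaled to pixels, compared to the target box with a mean L1 loss, and the remaining activations are scored with cross-entropy multiplied by 20. detn_loss_np and its helpers are illustrative names; the sketch assumes 20 classes (Pascal VOC) and mean reduction, as F.l1_loss and F.cross_entropy use by default.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def log_softmax(logits):
    # Subtract the row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def detn_loss_np(output, bb_t, c_t, sz=224, ce_mult=20):
    """L1 loss on boxes plus scaled cross-entropy on classes."""
    bb_i, c_i = output[:, :4], output[:, 4:]
    bb_i = sigmoid(bb_i) * sz                          # box activations -> [0, sz] pixels
    l1 = np.abs(bb_i - bb_t).mean()                    # mean over all box coordinates
    ce = -log_softmax(c_i)[np.arange(len(c_t)), c_t].mean()
    return l1 + ce * ce_mult

# All-zero activations: every box coordinate becomes 112, every class equally likely.
out = np.zeros((2, 4 + 20))
c_t = np.array([3, 7])
print(round(detn_loss_np(out, np.full((2, 4), 112.0), c_t), 2))  # 59.91 (= 20 * ln 20)
print(round(detn_loss_np(out, np.full((2, 4), 122.0), c_t), 2))  # 69.91 (adds L1 of 10)
```

The example shows why the multiplier matters: an uninformed classifier contributes about 60 to the loss, comparable in size to the L1 pixel error of the box.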
def detn_l1(input, target):
    bb_t,_ = target
    bb_i = input[:, :4]
    bb_i = F.sigmoid(bb_i)*224
    return F.l1_loss(V(bb_i),V(bb_t)).data
def detn_acc(input, target):
    _,c_t = target
    c_i = input[:, 4:]
    return accuracy(c_i, c_t)
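Since detn_loss mixes a pixel error with a weighted cross-entropy, the two metrics report each half on its own scale: detn_acc is plain classification accuracy on columns 4 onward, and detn_l1 is the mean absolute box error in pixels. A NumPy sketch of both, with illustrative function names and the same sigmoid-times-224 scaling as the loss:

```python
import numpy as np

def detn_acc_np(output, c_t):
    """Fraction of rows whose highest class activation (columns 4+) matches the target."""
    return (output[:, 4:].argmax(axis=1) == c_t).mean()

def detn_l1_np(output, bb_t, sz=224):
    """Mean absolute pixel error of predicted boxes after sigmoid * sz scaling."""
    bb_i = sz / (1.0 + np.exp(-output[:, :4]))
    return np.abs(bb_i - bb_t).mean()

out = np.zeros((2, 4 + 20))
out[0, 4 + 3] = 5.0   # row 0 predicts class 3
out[1, 4 + 7] = 5.0   # row 1 predicts class 7
c_t = np.array([3, 0])
print(detn_acc_np(out, c_t))                       # 0.5
print(detn_l1_np(out, np.full((2, 4), 100.0)))     # 12.0 (every box coordinate is 112)
```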
learn.crit = detn_loss
learn.metrics = [detn_acc, detn_l1]
learn.lr_find()
learn.sched.plot()
lr=1e-2
learn.fit(lr, 1, cycle_len=3, use_clr=(32,5))
epoch  trn_loss   val_loss   detn_acc  detn_l1
0      72.036466  45.186367  0.802133  32.647586
1      51.037587  36.34964   0.828425  25.389733
2      41.4235    35.292709  0.835637  24.343577
[35.292709, 0.83563701808452606, 24.343576669692993]
learn.save('reg1_0')
learn.freeze_to(-2)
lrs = np.array([lr/100, lr/10, lr])
learn.lr_find(lrs/1000)
learn.sched.plot(0)
learn.fit(lrs/5, 1, cycle_len=5, use_clr=(32,10))
epoch  trn_loss   val_loss   detn_acc  detn_l1
0      34.448113  35.972973  0.801683  22.918499
1      28.889909  33.010857  0.830379  21.689888
2      24.237017  30.977512  0.81881   20.817996
3      21.132993  30.60677   0.83143   20.138552
4      18.622983  30.54178   0.825571  19.832196
[30.54178, 0.82557091116905212, 19.832195997238159]
learn.save('reg1_1')
learn.load('reg1_1')
learn.unfreeze()
learn.fit(lrs/10, 1, cycle_len=10, use_clr=(32,10))
epoch  trn_loss   val_loss   detn_acc  detn_l1
0      15.957164  31.111507  0.811448  19.970753
1      15.955259  32.597153  0.81235   20.111022
2      15.648723  32.231941  0.804087  19.522853
3      14.876172  30.93821   0.815805  19.226574
4      14.113872  31.03952   0.808594  19.155093
5      13.293885  29.736671  0.826022  18.761728
6      12.562566  30.000023  0.827524  18.82006
7      11.885125  30.28841   0.82512   18.904158
8      11.498326  30.070133  0.819712  18.635296
9      11.015841  30.213772  0.815805  18.551489
[30.213772, 0.81580528616905212, 18.551488876342773]
learn.save('reg1')
learn.load('reg1')
y = learn.predict()
x,_ = next(iter(md.val_dl))
from scipy.special import expit
fig, axes = plt.subplots(3, 4, figsize=(12, 8))
for i,ax in enumerate(axes.flat):
    ima=md.val_ds.ds.denorm(to_np(x))[i]
    bb = expit(y[i][:4])*224
    b = bb_hw(bb)
    c = np.argmax(y[i][4:])
    ax = show_img(ima, ax=ax)
    draw_rect(ax, b)
    draw_text(ax, b[:2], md2.classes[c])
plt.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
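learn.predict() returns raw activations, so the plotting loop has to apply the same decoding the loss used: sigmoid times 224 for the first 4 values, argmax over the rest for the class. The sketch below isolates that step; decode_prediction is a hypothetical helper, it computes the sigmoid directly (equivalent to scipy's expit) to avoid the SciPy dependency, and the three-name class list is just an illustrative subset.

```python
import numpy as np

def decode_prediction(pred, classes, sz=224):
    """Split one row of raw model output into a pixel bounding box and a class name."""
    bb = sz / (1.0 + np.exp(-pred[:4]))   # same as expit(pred[:4]) * sz
    c = int(np.argmax(pred[4:]))          # index of the largest class activation
    return bb, classes[c]

classes = ['aeroplane', 'bicycle', 'bird']   # illustrative subset of class names
pred = np.zeros(4 + len(classes))
pred[4 + 2] = 3.0
bb, name = decode_prediction(pred, classes)
print(bb, name)   # [112. 112. 112. 112.] bird
```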