%matplotlib inline
%reload_ext autoreload
%autoreload 2
from fastai.conv_learner import *
from fastai.dataset import *
from pathlib import Path
import json
torch.cuda.set_device(1)
PATH = Path('data/carvana')
list(PATH.iterdir())
[PosixPath('data/carvana/train_masks.csv'), PosixPath('data/carvana/train_masks-128'), PosixPath('data/carvana/sample_submission.csv'), PosixPath('data/carvana/train_masks_png'), PosixPath('data/carvana/train.csv'), PosixPath('data/carvana/train-128'), PosixPath('data/carvana/train'), PosixPath('data/carvana/metadata.csv'), PosixPath('data/carvana/tmp'), PosixPath('data/carvana/models'), PosixPath('data/carvana/train_masks')]
MASKS_FN = 'train_masks.csv'
META_FN = 'metadata.csv'
TRAIN_DN = 'train'
MASKS_DN = 'train_masks'
masks_csv = pd.read_csv(PATH/MASKS_FN)
masks_csv.head()
| | img | rle_mask |
|---|---|---|
| 0 | 00087a6bd4dc_01.jpg | 879386 40 881253 141 883140 205 885009 17 8850... |
| 1 | 00087a6bd4dc_02.jpg | 873779 4 875695 7 877612 9 879528 12 881267 15... |
| 2 | 00087a6bd4dc_03.jpg | 864300 9 866217 13 868134 15 870051 16 871969 ... |
| 3 | 00087a6bd4dc_04.jpg | 879735 20 881650 26 883315 92 883564 30 885208... |
| 4 | 00087a6bd4dc_05.jpg | 883365 74 883638 28 885262 119 885550 34 88716... |
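The rle_mask column is the competition's run-length encoding of each mask. This notebook works from the mask images on disk instead, but a minimal decoder sketch is included here for reference. It assumes the usual Kaggle convention (1-indexed start/length pairs over a top-to-bottom, then left-to-right pixel numbering) and the original 1918x1280 Carvana image size; treat both as assumptions to verify against the files in train_masks before relying on it.
def rle_decode(rle, shape=(1280, 1918)):
    # Sketch only: 1-indexed starts, column-major pixel numbering, (height, width) output.
    s = np.array(rle.split(), dtype=int)
    starts, lengths = s[0::2] - 1, s[1::2]
    mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for st, ln in zip(starts, lengths): mask[st:st + ln] = 1
    return mask.reshape(shape, order='F')
# e.g. rle_decode(masks_csv['rle_mask'][0]).sum() gives the mask area in pixels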
meta_csv = pd.read_csv(PATH/META_FN)
meta_csv.head()
| | id | year | make | model | trim1 | trim2 |
|---|---|---|---|---|---|---|
| 0 | 0004d4463b50 | 2014.0 | Acura | TL | TL | w/SE |
| 1 | 00087a6bd4dc | 2014.0 | Acura | RLX | RLX | w/Tech |
| 2 | 000aa097d423 | 2012.0 | Mazda | MAZDA6 | MAZDA6 | i Sport |
| 3 | 000f19f6e7d4 | 2016.0 | Chevrolet | Camaro | Camaro | SS |
| 4 | 00144e887ae9 | 2015.0 | Acura | TLX | TLX | SH-AWD V6 w/Advance Pkg |
def show_img(im, figsize=None, ax=None, alpha=None):
if not ax: fig,ax = plt.subplots(figsize=figsize)
ax.imshow(im, alpha=alpha)
ax.set_axis_off()
return ax
CAR_ID = '00087a6bd4dc'
list((PATH/TRAIN_DN).iterdir())[:5]
[PosixPath('data/carvana/train/5ab34f0e3ea5_15.jpg'), PosixPath('data/carvana/train/de3ca5ec1e59_07.jpg'), PosixPath('data/carvana/train/28d9a149cb02_13.jpg'), PosixPath('data/carvana/train/36a3f7f77e85_12.jpg'), PosixPath('data/carvana/train/843763f47895_08.jpg')]
Image.open(PATH/TRAIN_DN/f'{CAR_ID}_01.jpg').resize((300,200))
list((PATH/MASKS_DN).iterdir())[:5]
[PosixPath('data/carvana/train_masks/6c0cd487abcd_03_mask.gif'), PosixPath('data/carvana/train_masks/351c583eabd6_01_mask.gif'), PosixPath('data/carvana/train_masks/90fdd8932877_02_mask.gif'), PosixPath('data/carvana/train_masks/28d9a149cb02_10_mask.gif'), PosixPath('data/carvana/train_masks/88bc32b9e1d9_14_mask.gif')]
Image.open(PATH/MASKS_DN/f'{CAR_ID}_01_mask.gif').resize((300,200))
ims = [open_image(PATH/TRAIN_DN/f'{CAR_ID}_{i+1:02d}.jpg') for i in range(16)]
fig, axes = plt.subplots(4, 4, figsize=(9, 6))
for i,ax in enumerate(axes.flat): show_img(ims[i], ax=ax)
plt.tight_layout(pad=0.1)
(PATH/'train_masks_png').mkdir(exist_ok=True)
def convert_img(fn):
fn = fn.name
Image.open(PATH/'train_masks'/fn).save(PATH/'train_masks_png'/f'{fn[:-4]}.png')
files = list((PATH/'train_masks').iterdir())
with ThreadPoolExecutor(8) as e: e.map(convert_img, files)
Wall time: 27.4 s
(PATH/'train_masks-128').mkdir(exist_ok=True)
def resize_mask(fn):
Image.open(fn).resize((128,128)).save((fn.parent.parent)/'train_masks-128'/fn.name)
files = list((PATH/'train_masks_png').iterdir())
with ThreadPoolExecutor(8) as e: e.map(resize_mask, files)
(PATH/'train-128').mkdir(exist_ok=True)
def resize_img(fn):
Image.open(fn).resize((128,128)).save((fn.parent.parent)/'train-128'/fn.name)
files = list((PATH/'train').iterdir())
with ThreadPoolExecutor(8) as e: e.map(resize_img, files)
TRAIN_DN = 'train-128'
MASKS_DN = 'train_masks-128'
sz = 128
bs = 64
ims = [open_image(PATH/TRAIN_DN/f'{CAR_ID}_{i+1:02d}.jpg') for i in range(16)]
im_masks = [open_image(PATH/MASKS_DN/f'{CAR_ID}_{i+1:02d}_mask.png') for i in range(16)]
fig, axes = plt.subplots(4, 4, figsize=(9, 6))
for i,ax in enumerate(axes.flat):
ax = show_img(ims[i], ax=ax)
show_img(im_masks[i][...,0], ax=ax, alpha=0.5)
plt.tight_layout(pad=0.1)
class MatchedFilesDataset(FilesDataset):
def __init__(self, fnames, y, transform, path):
self.y=y
assert(len(fnames)==len(y))
super().__init__(fnames, transform, path)
def get_y(self, i): return open_image(os.path.join(self.path, self.y[i]))
def get_c(self): return 0
x_names = np.array([Path(TRAIN_DN)/o for o in masks_csv['img']])
y_names = np.array([Path(MASKS_DN)/f'{o[:-4]}_mask.png' for o in masks_csv['img']])
len(x_names)//16//5*16  # 16 views per car: hold out a fifth of the cars (rounded down) as whole cars
1008
val_idxs = list(range(1008))
((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, x_names, y_names)
len(val_x),len(trn_x)
(1008, 4080)
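Each car has 16 views and masks_csv appears to be sorted by image name, so taking the first 1008 rows (63 x 16) puts whole cars into the validation set; no car's views are split across train and validation. A quick sanity-check sketch, assuming the car id is the filename prefix before the underscore (the 63/255 counts just follow from the 1008/4080 split above):
val_cars = {o.name.split('_')[0] for o in val_x}
trn_cars = {o.name.split('_')[0] for o in trn_x}
assert not (val_cars & trn_cars)   # no car appears in both sets
len(val_cars), len(trn_cars)       # expect (63, 255)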
aug_tfms = [RandomRotate(4, tfm_y=TfmType.CLASS),
RandomFlip(tfm_y=TfmType.CLASS),
RandomLighting(0.05, 0.05)]
# aug_tfms = []
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
md = ImageData(PATH, datasets, bs, num_workers=8, classes=None)
denorm = md.trn_ds.denorm
x,y = next(iter(md.aug_dl))
x = denorm(x)
fig, axes = plt.subplots(5, 6, figsize=(12, 10))
for i,ax in enumerate(axes.flat):
ax=show_img(x[i], ax=ax)
show_img(y[i], ax=ax, alpha=0.5)
plt.tight_layout(pad=0.1)
class Empty(nn.Module):
def forward(self,x): return x
models = ConvnetBuilder(resnet34, 0, 0, 0, custom_head=Empty())
learn = ConvLearner(md, models)
learn.summary()
class StdUpsample(nn.Module):
def __init__(self, nin, nout):
super().__init__()
self.conv = nn.ConvTranspose2d(nin, nout, 2, stride=2)
self.bn = nn.BatchNorm2d(nout)
def forward(self, x): return self.bn(F.relu(self.conv(x)))
flatten_channel = Lambda(lambda x: x[:,0])  # (bs,1,sz,sz) -> (bs,sz,sz) so predictions match the 2D mask targets
simple_up = nn.Sequential(
nn.ReLU(),
StdUpsample(512,256),
StdUpsample(256,256),
StdUpsample(256,256),
StdUpsample(256,256),
nn.ConvTranspose2d(256, 1, 2, stride=2),
flatten_channel
)
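At sz=128 the cut ResNet-34 backbone emits a 512x4x4 feature map; each StdUpsample and the final ConvTranspose2d doubles the spatial size (five doublings: 4 -> 128), and flatten_channel drops the singleton channel. A quick shape check of the head on its own (a sketch; run it here, before the head is handed to ConvnetBuilder):
from torch.autograd import Variable
dummy = Variable(torch.randn(2, 512, 4, 4))   # shape the cut ResNet-34 produces for 128px inputs
simple_up(dummy).size()                       # expect (2, 128, 128)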
models = ConvnetBuilder(resnet34, 0, 0, 0, custom_head=simple_up)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5)]
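The loss is per-pixel binary cross-entropy on the logits, and accuracy_thresh(0.5) reports per-pixel accuracy (the <lambda> / mask_acc column in the training logs below). A rough standalone sketch of the idea, not fastai's implementation; whether the threshold is applied to raw outputs or to sigmoid probabilities is an assumption here:
def pixel_acc(logits, targs, thresh=0.5):
    # Fraction of pixels where the thresholded prediction agrees with the mask.
    preds = (torch.sigmoid(logits) > thresh).float()
    return (preds == targs).float().mean()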
learn.lr_find()
learn.sched.plot()
94%|█████████▍| 30/32 [00:05<00:00, 5.48it/s, loss=10.6]
lr=4e-2
learn.fit(lr,1,cycle_len=5,use_clr=(20,5))
epoch  trn_loss  val_loss  <lambda>
0      0.124078  0.133566  0.945951
1      0.111241  0.112318  0.954912
2      0.099743  0.09817   0.957507
3      0.090651  0.092375  0.958117
4      0.084031  0.086026  0.963243
[0.086025625, 0.96324310824275017]
learn.save('tmp')
learn.load('tmp')
py,ay = learn.predict_with_targs()
ay.shape
(1008, 128, 128)
show_img(ay[0]);
show_img(py[0]>0);  # predictions are logits (BCEWithLogitsLoss), so >0 corresponds to sigmoid probability > 0.5
learn.unfreeze()
learn.bn_freeze(True)
lrs = np.array([lr/100,lr/10,lr])/4
learn.fit(lrs,1,cycle_len=20,use_clr=(20,10))
epoch  trn_loss  val_loss  <lambda>
0      0.06577   0.053292  0.972977
1      0.049475  0.043025  0.982559
2      0.039146  0.035927  0.98337
3      0.03405   0.031903  0.986982
4      0.029788  0.029065  0.987944
5      0.027374  0.027752  0.988029
6      0.026041  0.026718  0.988226
7      0.024302  0.025927  0.989512
8      0.022921  0.026102  0.988276
9      0.021944  0.024714  0.989537
10     0.021135  0.0241    0.990628
11     0.020494  0.023367  0.990652
12     0.01988   0.022961  0.990989
13     0.019241  0.022498  0.991014
14     0.018697  0.022492  0.990571
15     0.01812   0.021771  0.99105
16     0.017597  0.02183   0.991365
17     0.017192  0.021434  0.991364
18     0.016768  0.021383  0.991643
19     0.016418  0.021114  0.99173
[0.021113895, 0.99172959849238396]
learn.save('0')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))
ax = show_img(denorm(x)[0])
show_img(py[0]>0, ax=ax, alpha=0.5);
ax = show_img(denorm(x)[0])
show_img(y[0], ax=ax, alpha=0.5);
TRAIN_DN = 'train'
MASKS_DN = 'train_masks_png'
sz = 512
bs = 16
x_names = np.array([Path(TRAIN_DN)/o for o in masks_csv['img']])
y_names = np.array([Path(MASKS_DN)/f'{o[:-4]}_mask.png' for o in masks_csv['img']])
((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, x_names, y_names)
len(val_x),len(trn_x)
(1008, 4080)
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
md = ImageData(PATH, datasets, bs, num_workers=8, classes=None)
denorm = md.trn_ds.denorm
x,y = next(iter(md.aug_dl))
x = denorm(x)
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for i,ax in enumerate(axes.flat):
ax=show_img(x[i], ax=ax)
show_img(y[i], ax=ax, alpha=0.5)
plt.tight_layout(pad=0.1)
simple_up = nn.Sequential(
nn.ReLU(),
StdUpsample(512,256),
StdUpsample(256,256),
StdUpsample(256,256),
StdUpsample(256,256),
nn.ConvTranspose2d(256, 1, 2, stride=2),
flatten_channel
)
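Both the ResNet-34 backbone and this head are fully convolutional, which is why the 128px weights saved as '0' can simply be reloaded at sz=512: the backbone now emits a 512x16x16 grid and the same five doublings bring it to 512x512. A quick shape check (sketch, same caveats as before):
dummy = Variable(torch.randn(2, 512, 16, 16))   # backbone output for 512px inputs
simple_up(dummy).size()                         # expect (2, 512, 512)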
models = ConvnetBuilder(resnet34, 0, 0, 0, custom_head=simple_up)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5)]
learn.load('0')
learn.lr_find()
learn.sched.plot()
85%|████████▌ | 218/255 [02:12<00:22, 1.64it/s, loss=8.91]
lr=4e-2
learn.fit(lr,1,cycle_len=5,use_clr=(20,5))
epoch  trn_loss  val_loss  <lambda>
0      0.02178   0.020653  0.991708
1      0.017927  0.020653  0.990241
2      0.015958  0.016115  0.993394
3      0.015172  0.015143  0.993696
4      0.014315  0.014679  0.99388
[0.014679321, 0.99388032489352751]
learn.save('tmp')
learn.load('tmp')
learn.unfreeze()
learn.bn_freeze(True)
lrs = np.array([lr/100,lr/10,lr])/4
learn.fit(lrs,1,cycle_len=8,use_clr=(20,8))
epoch  trn_loss  val_loss  mask_acc
0      0.038687  0.018685  0.992782
1      0.024906  0.014355  0.994933
2      0.025055  0.014737  0.995526
3      0.024155  0.014083  0.995708
4      0.013446  0.010564  0.996166
5      0.01607   0.010555  0.996096
6      0.019197  0.010883  0.99621
7      0.016157  0.00998   0.996393
[0.0099797687, 0.99639255659920833]
learn.save('512')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))
ax = show_img(denorm(x)[0])
show_img(py[0]>0, ax=ax, alpha=0.5);
ax = show_img(denorm(x)[0])
show_img(y[0], ax=ax, alpha=0.5);
sz = 1024
bs = 4
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
md = ImageData(PATH, datasets, bs, num_workers=8, classes=None)
denorm = md.trn_ds.denorm
x,y = next(iter(md.aug_dl))
x = denorm(x)
y = to_np(y)
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for i,ax in enumerate(axes.flat):
show_img(x[i], ax=ax)
show_img(y[i], ax=ax, alpha=0.5)
plt.tight_layout(pad=0.1)
simple_up = nn.Sequential(
nn.ReLU(),
StdUpsample(512,256),
StdUpsample(256,256),
StdUpsample(256,256),
StdUpsample(256,256),
nn.ConvTranspose2d(256, 1, 2, stride=2),
flatten_channel,
)
models = ConvnetBuilder(resnet34, 0, 0, 0, custom_head=simple_up)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5)]
learn.load('512')
learn.lr_find()
learn.sched.plot()
lr=4e-2
learn.fit(lr,1,cycle_len=2,use_clr=(20,4))
epoch  trn_loss  val_loss  <lambda>
0      0.01066   0.011119  0.996227
1      0.009357  0.009696  0.996553
[0.0096957013, 0.99655332546385511]
learn.save('tmp')
learn.load('tmp')
learn.unfreeze()
learn.bn_freeze(True)
lrs = np.array([lr/100,lr/10,lr])/8
learn.fit(lrs,1,cycle_len=40,use_clr=(20,10))
epoch  trn_loss  val_loss  mask_acc
0      0.015565  0.007449  0.997661
1      0.01979   0.008376  0.997542
2      0.014874  0.007826  0.997736
3      0.016104  0.007854  0.997347
4      0.023386  0.009745  0.997218
5      0.018972  0.008453  0.997588
6      0.013184  0.007612  0.997588
7      0.010686  0.006775  0.997688
8      0.0293    0.015299  0.995782
9      0.018713  0.00763   0.997638
10     0.015432  0.006575  0.9978
11     0.110205  0.060062  0.979043
12     0.014374  0.007753  0.997451
13     0.022286  0.010282  0.997587
14     0.015645  0.00739   0.997776
15     0.013821  0.00692   0.997869
16     0.022389  0.008632  0.997696
17     0.014607  0.00677   0.997837
18     0.018748  0.008194  0.997657
19     0.016447  0.007237  0.997899
20     0.023596  0.008211  0.997918
21     0.015721  0.00674   0.997848
22     0.01572   0.006415  0.998006
23     0.019519  0.007591  0.997876
24     0.011159  0.005998  0.998053
25     0.010291  0.005806  0.998012
26     0.010893  0.005755  0.998046
27     0.014534  0.006313  0.997901
28     0.020971  0.006855  0.998018
29     0.014074  0.006107  0.998053
30     0.01782   0.006561  0.998114
31     0.01742   0.006414  0.997942
32     0.016829  0.006514  0.9981
33     0.013148  0.005819  0.998033
34     0.023495  0.006261  0.997856
35     0.010931  0.005516  0.99812
36     0.015798  0.006176  0.998126
37     0.021636  0.005931  0.998067
38     0.012133  0.005496  0.998158
39     0.012562  0.005678  0.998172
[0.0056782686, 0.99817223208291195]
learn.save('1024')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))
ax = show_img(denorm(x)[0])
show_img(py[0][0]>0, ax=ax, alpha=0.5);
ax = show_img(denorm(x)[0])
show_img(y[0,...,-1], ax=ax, alpha=0.5);
show_img(py[0][0]>0);
show_img(y[0,...,-1]);