%matplotlib inline
%reload_ext autoreload
%autoreload 2
from fastai.conv_learner import *
PATH = "data/cifar10/"
os.makedirs(PATH,exist_ok=True)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))
def get_data(sz,bs):
tfms = tfms_from_stats(stats, sz, aug_tfms=[RandomFlip()], pad=sz//8)
return ImageClassifierData.from_paths(PATH, val_name='test', tfms=tfms, bs=bs)
bs=128
data = get_data(32,4)
x,y=next(iter(data.trn_dl))
plt.imshow(data.trn_ds.denorm(x)[0]);
plt.imshow(data.trn_ds.denorm(x)[1]);
from fastai.models.cifar10.resnext import resnext29_8_64
m = resnext29_8_64()
bm = BasicModel(m.cuda(), name='cifar10_rn29_8_64')
data = get_data(8,bs*4)
learn = ConvLearner(data, bm)
learn.unfreeze()
lr=1e-2; wd=5e-4
learn.lr_find()
A Jupyter Widget
[ 0. 2.41622 23356282. 0.09889]
learn.sched.plot()
%time learn.fit(lr, 1)
A Jupyter Widget
[ 0. 1.69719 1.58988 0.43418] CPU times: user 33.3 s, sys: 7.2 s, total: 40.5 s Wall time: 40.8 s
learn.fit(lr, 2, cycle_len=1)
A Jupyter Widget
[ 0. 1.43256 1.39725 0.49746] [ 1. 1.35969 1.3411 0.51602]
learn.fit(lr, 3, cycle_len=1, cycle_mult=2, wds=wd)
A Jupyter Widget
[ 0. 1.31041 1.29344 0.53174] [ 1. 1.30313 1.31292 0.53418] [ 2. 1.15682 1.22019 0.5668 ] [ 3. 1.26632 1.34606 0.54121] [ 4. 1.14698 1.18958 0.57598] [ 5. 1.02205 1.13905 0.60254] [ 6. 0.93291 1.13761 0.60596]
learn.save('8x8_8')
learn.load('8x8_8')
learn.set_data(get_data(16,bs*2))
%time learn.fit(1e-3, 1, wds=wd)
A Jupyter Widget
[ 0. 5.14421 5.90379 0.28985]
learn.unfreeze()
learn.lr_find()
A Jupyter Widget
[ 0. 2.28283 966.72705 0.09523]
learn.sched.plot()
lr=1e-2
learn.fit(lr, 2, cycle_len=1, wds=wd)
A Jupyter Widget
[ 0. 1.38133 1.36374 0.51137] [ 1. 1.13313 1.14011 0.60384]
learn.fit(lr, 3, cycle_len=1, cycle_mult=2, wds=wd)
A Jupyter Widget
[ 0. 0.98926 0.98239 0.6605 ] [ 1. 0.9254 0.93705 0.67257] [ 2. 0.75597 0.785 0.72646] [ 3. 0.84544 0.90955 0.68958] [ 4. 0.66637 0.71065 0.75752] [ 5. 0.56253 0.63 0.78066] [ 6. 0.4769 0.60606 0.79114]
learn.save('16x16_8')
learn.load('16x16_8')
learn.set_data(get_data(24,bs))
learn.fit(1e-2, 1, wds=wd)
A Jupyter Widget
[ 0. 0.61302 0.67112 0.77321]
learn.unfreeze()
learn.fit(lr, 1, cycle_len=1, wds=wd)
A Jupyter Widget
[ 0. 0.49323 0.51107 0.82698]
learn.fit(lr, 3, cycle_len=1, cycle_mult=2, wds=wd)
A Jupyter Widget
[ 0. 0.43563 0.47423 0.8373 ] [ 1. 0.46499 0.51238 0.82282] [ 2. 0.32453 0.40952 0.86141] [ 3. 0.47472 0.62542 0.78919] [ 4. 0.34746 0.43835 0.85566] [ 5. 0.24327 0.37241 0.87728] [ 6. 0.19959 0.35643 0.88214]
learn.save('24x24_8')
log_preds,y = learn.TTA()
preds = np.mean(np.exp(log_preds),0)metrics.log_loss(y,preds), accuracy(preds,y)
(0.32011880736219389, 0.89410000000000001)
learn.load('24x24_8')
learn.set_data(get_data(32,bs))
learn.fit(1e-2, 1, wds=wd)
A Jupyter Widget
[ 0. 0.28159 0.37857 0.87114]
learn.unfreeze()
learn.fit(lr, 3, cycle_len=1, cycle_mult=2, wds=wd)
A Jupyter Widget
[ 0. 0.24112 0.3168 0.89587] [ 1. 0.28276 0.38893 0.87342] [ 2. 0.16777 0.29451 0.90378] [ 3. 0.27507 0.57156 0.81695] [ 4. 0.20369 0.35616 0.88786] [ 5. 0.12198 0.2949 0.90437] [ 6. 0.08032 0.27401 0.91169]
learn.fit(lr, 3, cycle_len=4, wds=wd)
A Jupyter Widget
[ 0. 0.23548 0.43751 0.86086] [ 1. 0.15057 0.32182 0.89616] [ 2. 0.08559 0.28051 0.91218] [ 3. 0.05464 0.26494 0.91555] [ 4. 0.18577 0.4403 0.86541] [ 5. 0.11875 0.32632 0.89943] [ 6. 0.06624 0.26536 0.91911] [ 7. 0.03661 0.26202 0.92059] [ 8. 0.16668 0.39448 0.87718] [ 9. 0.09441 0.31163 0.9019 ] [ 10. 0.0395 0.26366 0.92 ] [ 11. 0.02982 0.25696 0.92375]
log_preds,y = learn.TTA()
metrics.log_loss(y,np.exp(log_preds)), accuracy(log_preds,y)
(0.22727449907067793, 0.93049999999999999)
learn.save('32x32_8')