CIFAR 10

In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load (or reload) the autoreload extension, then use mode 2 so that edits to
# imported modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
In [2]:
from fastai.conv_learner import *
from fastai.models.cifar10.wideresnet import wrn_22
# Dataset directory; created up front (including parents) when it is missing.
PATH = Path("data/cifar10/")
PATH.mkdir(parents=True, exist_ok=True)

# Switch on the cuDNN benchmark flag before any model is built.
torch.backends.cudnn.benchmark = True
In [3]:
# Human-readable class labels (not referenced by the other visible cells).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
# Normalisation statistics handed to tfms_from_stats: two 3-element arrays,
# presumably the per-channel (mean, std) of the images — TODO confirm.
stats = (
    np.array([0.4914, 0.48216, 0.44653]),
    np.array([0.24703, 0.24349, 0.26159]),
)

sz = 32    # image size given to tfms_from_stats and RandomCrop
bs = 512   # batch size given to ImageClassifierData.from_paths
In [4]:
# Transforms built from the normalisation stats at size `sz`, with random-crop
# and random-flip augmentation; pad is sz//8 (4 when sz is 32).
tfms = tfms_from_stats(
    stats,
    sz,
    aug_tfms=[RandomCrop(sz), RandomFlip()],
    pad=sz // 8,
)
# Data object read from the folders under PATH; val_name='test' presumably
# selects the `test` folder as the validation set — confirm against fastai.
data = ImageClassifierData.from_paths(
    PATH,
    val_name='test',
    tfms=tfms,
    bs=bs,
)
In [5]:
# Instantiate the wrn_22 model imported from fastai.models.cifar10.wideresnet.
m = wrn_22()
In [6]:
# Hyperparameters consumed by the learn.fit call in the following cell.
lr = 1.5
wd = 1e-4

# Wrap the model and data in a learner, then set the metric it reports and the
# loss it optimises.
learn = ConvLearner.from_model_data(m, data)
learn.metrics = [accuracy]
learn.crit = nn.CrossEntropyLoss()
In [7]:
# Train one cycle of 30 epochs (the log below shows epochs 0-29) with weight
# decay `wd`, timing the whole call with %time.
# NOTE(review): use_clr_beta presumably holds (lr divisor, % of the cycle used
# for the final annealing, max momentum, min momentum) — confirm against fastai.
%time learn.fit(lr, 1, wds=wd, cycle_len=30, use_clr_beta=(20,20,0.95,0.85))
epoch      trn_loss   val_loss   accuracy                 
    0      1.456755   1.499619   0.5062    
    1      1.057333   1.157792   0.6116                   
    2      0.829041   0.783326   0.723                     
    3      0.66619    0.808943   0.7358                    
    4      0.570876   0.748631   0.7361                    
    5      0.495598   1.038086   0.6717                    
    6      0.448354   0.533581   0.8181                    
    7      0.415957   0.546836   0.816                     
    8      0.390528   0.61025    0.7827                    
    9      0.36144    0.751214   0.764                     
    10     0.351138   0.756213   0.7769                    
    11     0.33065    0.872244   0.7549                    
    12     0.323868   0.530568   0.8215                    
    13     0.301522   0.633277   0.8                       
    14     0.281426   0.609825   0.8141                    
    15     0.261843   0.792786   0.7706                    
    16     0.243936   0.727103   0.797                     
    17     0.233351   0.481732   0.8525                    
    18     0.219056   0.522896   0.8375                    
    19     0.196971   0.350686   0.8835                    
    20     0.180855   0.389286   0.8754                    
    21     0.150032   0.372619   0.883                     
    22     0.118364   0.255543   0.9182                    
    23     0.080524   0.22061    0.9311                     
    24     0.051989   0.207242   0.9347                     
    25     0.03802    0.21347    0.9368                     
    26     0.030564   0.211374   0.9381                     
    27     0.023117   0.214783   0.9398                     
    28     0.020133   0.21228    0.9421                     
    29     0.017761   0.212101   0.9423                     

CPU times: user 34min 14s, sys: 54min 24s, total: 1h 28min 38s
Wall time: 17min 16s
Out[7]:
[array([0.2121]), 0.9423000004768372]