#!/usr/bin/env python
# coding: utf-8
"""Train a 22-layer Wide ResNet on CIFAR-10 with ktrain (exported notebook)."""

# In[1]:

# Notebook-only magics. get_ipython() only exists inside IPython/Jupyter, so
# guard it so the file can also be run as a plain `python` script.
try:
    _ipython = get_ipython()  # noqa: F821 -- injected by IPython at runtime
except NameError:
    _ipython = None
if _ipython is not None:
    _ipython.run_line_magic('reload_ext', 'autoreload')
    _ipython.run_line_magic('autoreload', '2')
    _ipython.run_line_magic('matplotlib', 'inline')

import os
# Pin the process to the first GPU, enumerated in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys


# In[2]:

import ktrain
from ktrain import vision
import tensorflow.keras.backend as K


# In[3]:

# Load CIFAR-10 and standardize manually.
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
# Compute the per-pixel statistics on the TRAINING set only and apply the same
# transform to both splits. Standardizing the test set with its own statistics
# leaks test information into preprocessing, and it feeds the model inputs
# scaled differently from those it was trained on.
train_mean = x_train.mean(axis=0)
train_std = x_train.std(axis=0)
x_train = (x_train - train_mean) / train_std
x_test = (x_test - train_mean) / train_std
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)


# In[4]:

print(x_train[0].shape)


# In[11]:

# Images are channels-last (height, width, channels); see the shape printed above.
input_shape = (32, 32, 3)


# In[9]:

# Define data augmentation. Turn featurewise* off, since the data were already
# standardized manually above.
data_aug = vision.get_data_aug(featurewise_center=False,
                               featurewise_std_normalization=False,
                               horizontal_flip=True,
                               width_shift_range=0.1,
                               height_shift_range=0.1,
                               zoom_range=0.0,
                               rotation_range=10)


# In[12]:

# Load training and validation data as generators with data augmentation.
(train_data, val_data, preproc) = vision.images_from_array(x_train, y_train,
                                                           validation_data=(x_test, y_test),
                                                           data_aug=data_aug)


# In[13]:

# Let's examine the available image classifiers.
vision.print_image_classifiers()


# In[14]:

# Load a 22-layer Wide ResNet.
model = vision.image_classifier('wrn22', train_data, val_data)


# In[16]:

# Get a Learner object to be used in training.
learner = ktrain.get_learner(model, train_data=train_data, val_data=val_data,
                             workers=8, use_multiprocessing=True, batch_size=64)


# In[17]:

# Find a good learning rate.
learner.lr_find()


# In[18]:

learner.lr_plot()


# In[19]:

# Fit using the one-cycle policy.
learner.fit_onecycle(1e-3, 30)