#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
TensorFlow Hub is a way to share pretrained model components. See the TensorFlow Module Hub for a searchable listing of pretrained models.
This tutorial demonstrates how to use TensorFlow Hub with tf.keras.
!pip install tensorflow_hub
from __future__ import absolute_import, division, print_function
import matplotlib.pylab as plt
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers
tf.VERSION
For this example we'll use the TensorFlow flowers dataset:
data_root = tf.keras.utils.get_file(
'flower_photos','https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
untar=True)
The simplest way to load this data into our model is using tf.keras.preprocessing.image.ImageDataGenerator:
All of TensorFlow Hub's image modules expect float inputs in the [0, 1] range. Use the ImageDataGenerator's rescale parameter to achieve this.
The image size will be handled later.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)
image_data = image_generator.flow_from_directory(str(data_root))
The resulting object is an iterator that returns image_batch, label_batch pairs.
for image_batch, label_batch in image_data:
    print("Image batch shape: ", image_batch.shape)
    print("Label batch shape: ", label_batch.shape)
    break
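To make the rescale step above concrete, here is a minimal sketch in plain Python. The pixel values are made up for illustration; the only thing taken from the tutorial is the factor 1/255, which maps 8-bit intensities onto the [0, 1] range that the image modules expect.

```python
# Same factor that is passed as rescale=1/255 to ImageDataGenerator.
RESCALE = 1 / 255

# Hypothetical 8-bit pixel intensities (not taken from the flowers dataset).
raw_pixels = [0, 64, 128, 255]

# Multiplying by the rescale factor maps 0..255 onto 0..1.
scaled_pixels = [value * RESCALE for value in raw_pixels]

print(scaled_pixels)
```

The generator applies the same multiplication to every pixel of every image it yields.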
Use hub.Module to load a MobileNet, and tf.keras.layers.Lambda to wrap it up as a Keras layer.
Any image classifier URL from tfhub.dev will work here.
classifier_url = "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/2" #@param {type:"string"}
def classifier(x):
    classifier_module = hub.Module(classifier_url)
    return classifier_module(x)
IMAGE_SIZE = hub.get_expected_image_size(hub.Module(classifier_url))
classifier_layer = layers.Lambda(classifier, input_shape = IMAGE_SIZE+[3])
classifier_model = tf.keras.Sequential([classifier_layer])
classifier_model.summary()
Rebuild the data generator, with the output size set to match what's expected by the module.
image_data = image_generator.flow_from_directory(str(data_root), target_size=IMAGE_SIZE)
for image_batch, label_batch in image_data:
    print("Image batch shape: ", image_batch.shape)
    print("Label batch shape: ", label_batch.shape)
    break
When using Keras, TFHub modules need to be manually initialized.
import tensorflow.keras.backend as K
sess = K.get_session()
init = tf.global_variables_initializer()
sess.run(init)
Download a single image to try the model on.
import numpy as np
import PIL.Image as Image
grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SIZE)
grace_hopper
grace_hopper = np.array(grace_hopper)/255.0
grace_hopper.shape
Add a batch dimension, and pass the image to the model.
result = classifier_model.predict(grace_hopper[np.newaxis, ...])
result.shape
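As a small illustration of the batch-dimension step, the sketch below uses a zero-filled placeholder array instead of the real photo. The 224x224x3 shape is an assumption based on the module name in the URL; in the tutorial the size comes from IMAGE_SIZE.

```python
import numpy as np

# Placeholder standing in for a single RGB image (assumed 224x224).
single_image = np.zeros((224, 224, 3), dtype=np.float32)

# np.newaxis inserts a leading axis of length 1, turning one image
# into a batch that contains exactly one image.
batch_of_one = single_image[np.newaxis, ...]

print(single_image.shape, batch_of_one.shape)
```

Keras models predict on batches, which is why the extra leading axis is needed even for a single image.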
The result is a 1001-element vector of logits, one score per class for the image. Logits are unnormalized, but a higher logit still means a more likely class.
So the top class ID can be found with argmax:
predicted_class = np.argmax(result[0], axis=-1)
predicted_class
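Because the model returns logits rather than probabilities, it can help to see why argmax works on either. The following is a minimal sketch in plain Python; the logits are hypothetical values for a four-class problem, not real model output, and the softmax helper is written here for illustration only.

```python
import math

def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    # Subtracting the largest logit keeps exp() numerically stable
    # and does not change the resulting probabilities.
    shift = max(logits)
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits (illustrative only).
example_logits = [1.0, 3.5, -0.5, 0.2]
probabilities = softmax(example_logits)

# Softmax is monotonic, so the index of the largest value is unchanged.
top_from_logits = max(range(len(example_logits)), key=example_logits.__getitem__)
top_from_probs = max(range(len(probabilities)), key=probabilities.__getitem__)

print(top_from_logits, top_from_probs)
```

If actual probabilities are needed, a softmax can be applied to the model output; for picking the top class it is not required.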
We now have the predicted class ID. Fetch the ImageNet labels and decode the prediction:
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())
plt.imshow(grace_hopper)
plt.axis('off')
predicted_class_name = imagenet_labels[predicted_class]
_ = plt.title("Prediction: " + predicted_class_name)
Now run the classifier on the image batch.
result_batch = classifier_model.predict(image_batch)
labels_batch = imagenet_labels[np.argmax(result_batch, axis=-1)]
labels_batch
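The batch decoding above relies on two NumPy features: argmax along the last axis and indexing an array of labels with an array of integers. Here is a minimal sketch with a hypothetical three-class label list and made-up scores, rather than the 1001 ImageNet labels.

```python
import numpy as np

# Hypothetical label list; index 0 plays the role of a "background" class.
labels = np.array(["background", "cat", "dog"])

# Made-up scores for a batch of four images, one row per image.
scores = np.array([
    [0.1, 2.0, 0.3],
    [1.5, 0.2, 0.1],
    [0.0, 0.1, 3.0],
    [0.2, 0.9, 0.4],
])

# argmax over the last axis gives one class ID per row.
class_ids = np.argmax(scores, axis=-1)

# Indexing with an integer array looks up every ID at once.
decoded = labels[class_ids]

print(class_ids, decoded)
```

The result has one label per image in the batch, in the same order as the rows of scores.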
Now check how these predictions line up with the images:
plt.figure(figsize=(10,9))
for n in range(30):
    plt.subplot(6,5,n+1)
    plt.imshow(image_batch[n])
    plt.title(labels_batch[n])
    plt.axis('off')
_ = plt.suptitle("ImageNet predictions")
See the LICENSE.txt file for image attributions.
The results are far from perfect, but reasonable considering that these are not the classes the model was trained for (except "daisy").
Using TensorFlow Hub, it is simple to retrain the top layer of the model to recognize the classes in our dataset.
TensorFlow Hub also distributes models without the top classification layer. These can be used to easily do transfer learning.
Any image feature vector url from tfhub.dev will work here.
feature_extractor_url = "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2" #@param {type:"string"}
Create the module, and check the expected image size:
def feature_extractor(x):
    feature_extractor_module = hub.Module(feature_extractor_url)
    return feature_extractor_module(x)
IMAGE_SIZE = hub.get_expected_image_size(hub.Module(feature_extractor_url))
Ensure the data generator is generating images of the expected size:
image_data = image_generator.flow_from_directory(str(data_root), target_size=IMAGE_SIZE)
for image_batch, label_batch in image_data:
    print("Image batch shape: ", image_batch.shape)
    print("Label batch shape: ", label_batch.shape)
    break
Wrap the module in a keras layer.
features_extractor_layer = layers.Lambda(feature_extractor, input_shape=IMAGE_SIZE+[3])
Freeze the variables in the feature extractor layer, so that the training only modifies the new classifier layer.
features_extractor_layer.trainable = False
Now wrap the hub layer in a tf.keras.Sequential model, and add a new classification layer.
model = tf.keras.Sequential([
    features_extractor_layer,
    layers.Dense(image_data.num_classes, activation='softmax')
])
model.summary()
Initialize the TFHub module.
init = tf.global_variables_initializer()
sess.run(init)
Test run a single batch, to see that the result comes back with the expected shape.
result = model.predict(image_batch)
result.shape
Use compile to configure the training process:
model.compile(
    optimizer=tf.train.AdamOptimizer(),
    loss='categorical_crossentropy',
    metrics=['accuracy'])
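The 'categorical_crossentropy' loss compares a one-hot label with the softmax output of the model. The sketch below computes it by hand in plain Python. The helper, the clipping constant, and the example predictions are illustrative assumptions, not Keras internals, and the five classes are assumed to match the flowers dataset.

```python
import math

def categorical_crossentropy(one_hot_label, predicted_probs, eps=1e-7):
    """Cross-entropy between a one-hot label and a predicted distribution."""
    loss = 0.0
    for target, prob in zip(one_hot_label, predicted_probs):
        # Clip to avoid log(0); eps is an illustrative choice.
        clipped = min(max(prob, eps), 1.0 - eps)
        loss -= target * math.log(clipped)
    return loss

num_classes = 5                      # assumed number of flower classes
label = [0, 0, 1, 0, 0]              # one-hot label for class index 2

# A model that guesses uniformly versus one that is confident and correct.
uniform_guess = [1.0 / num_classes] * num_classes
confident_guess = [0.01, 0.01, 0.96, 0.01, 0.01]

loss_uniform = categorical_crossentropy(label, uniform_guess)
loss_confident = categorical_crossentropy(label, confident_guess)

print(loss_uniform, loss_confident)
```

A uniform guess over five classes gives a loss of ln(5), roughly 1.61, which is a useful reference point when reading the loss curve of a freshly initialized classifier.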
Now use the .fit method to train the model.
To keep this example short, train for just a single epoch. To visualize the training progress during that epoch, use a custom callback to log the loss and accuracy of each batch.
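class CollectBatchStats(tf.keras.callbacks.Callback):
    def __init__(self):
        # Let the base Callback set up its own state first.
        super(CollectBatchStats, self).__init__()
        self.batch_losses = []
        self.batch_acc = []
    def on_batch_end(self, batch, logs=None):
        # Record the metrics that Keras reports for this batch.
        self.batch_losses.append(logs['loss'])
        self.batch_acc.append(logs['acc'])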
steps_per_epoch = image_data.samples//image_data.batch_size
batch_stats = CollectBatchStats()
model.fit((item for item in image_data), epochs=1,
          steps_per_epoch=steps_per_epoch,
          callbacks=[batch_stats])
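The steps_per_epoch value above comes from floor division, so a final partial batch is not counted as a step. A quick worked example follows. The sample count of 3670 is an assumption (the figure commonly reported for flower_photos), and 32 is the default batch size of flow_from_directory.

```python
# Assumed values; in the tutorial they come from image_data.samples
# and image_data.batch_size.
num_samples = 3670
batch_size = 32

# Floor division counts only the full batches.
steps_per_epoch = num_samples // batch_size

# Images that would form a final, smaller batch.
leftover_images = num_samples % batch_size

print(steps_per_epoch, leftover_images)
```

Under these assumptions one epoch runs 114 steps, and 22 images fall outside the counted full batches.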
Even after just a few training iterations, we can already see that the model is making progress on the task.
plt.figure()
plt.ylabel("Loss")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(batch_stats.batch_losses)
plt.figure()
plt.ylabel("Accuracy")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(batch_stats.batch_acc)
To redo the plot from before, first get the ordered list of class names:
label_names = sorted(image_data.class_indices.items(), key=lambda pair:pair[1])
label_names = np.array([key.title() for key, value in label_names])
label_names
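The ordering step above can be tried without the dataset. This sketch uses a hypothetical class_indices dictionary shaped like the one the generator exposes (directory name mapped to integer index); the names and indices are assumptions for illustration.

```python
# Hypothetical mapping, deliberately listed out of order.
class_indices = {
    "tulips": 4,
    "daisy": 0,
    "roses": 2,
    "dandelion": 1,
    "sunflowers": 3,
}

# Sort the (name, index) pairs by index so list position equals class ID.
ordered_pairs = sorted(class_indices.items(), key=lambda pair: pair[1])
label_names = [name.title() for name, _ in ordered_pairs]

# A predicted class ID can now be decoded by plain indexing.
predicted_id = 2
print(label_names, label_names[predicted_id])
```

Sorting by index matters because dictionary order is not guaranteed to match the class IDs that the model predicts.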
Run the image batch through the model and convert the indices to class names.
result_batch = model.predict(image_batch)
labels_batch = label_names[np.argmax(result_batch, axis=-1)]
labels_batch
Plot the result:
plt.figure(figsize=(10,9))
for n in range(30):
    plt.subplot(6,5,n+1)
    plt.imshow(image_batch[n])
    plt.title(labels_batch[n])
    plt.axis('off')
_ = plt.suptitle("Model predictions")
Now that you've trained the model, export it as a saved model:
export_path = tf.contrib.saved_model.save_keras_model(model, "./saved_models")
export_path