#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This tutorial trains DarkNet from the TensorFlow Model Garden package (tf-models-official) to classify images in the cats_vs_dogs dataset.
Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.
Dataset: cats_vs_dogs
This tutorial demonstrates how to:
1. Convert the cats_vs_dogs dataset from TensorFlow Datasets into TFRecords.
2. Configure and train a DarkNet image classifier using the Model Garden training libraries.
3. Export the trained model as a SavedModel and run inference with it.
! git clone -q https://github.com/tensorflow/models.git
! pip install -q -U tensorflow_datasets
! pip install -q --user -r models/official/requirements.txt
Note: After installing the requirements, please restart the runtime and then continue running the notebook from the next cell.
import os
import sys

os.environ['PYTHONPATH'] += ":/content/models"
sys.path.append("/content/models")
import pprint
import logging
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from official import core
from official.vision.data import tfrecord_lib
from official.vision import configs
from official.vision.configs import common
from official.projects.yolo.common import registry_imports
from official.projects.yolo.serving import export_saved_model
from official.projects.yolo.serving import export_module_factory
from official.vision.serving import export_saved_model_lib
logging.disable(logging.WARNING)
pp = pprint.PrettyPrinter(indent=4)
%matplotlib inline
(train_ds, validation_ds, test_ds), ds_info = tfds.load(
    name='cats_vs_dogs',
    split=['train[:70%]', 'train[70%:90%]', 'train[90%:100%]'],
    with_info=True)
label_info = ds_info.features['label']
ds_info
tfds.core.DatasetInfo(
    name='cats_vs_dogs',
    full_name='cats_vs_dogs/4.0.0',
    description="""
    A large set of images of cats and dogs. There are 1738 corrupted images
    that are dropped.
    """,
    homepage='https://www.microsoft.com/en-us/download/details.aspx?id=54765',
    data_dir='/root/tensorflow_datasets/cats_vs_dogs/4.0.0',
    file_format=tfrecord,
    download_size=786.67 MiB,
    dataset_size=689.64 MiB,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=uint8),
        'image/filename': Text(shape=(), dtype=string),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=2),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'train': <SplitInfo num_examples=23262, num_shards=8>,
    },
    citation="""@Inproceedings (Conference){asirra-a-captcha-that-exploits-interest-aligned-manual-image-categorization,
    author = {Elson, Jeremy and Douceur, John (JD) and Howell, Jon and Saul, Jared},
    title = {Asirra: A CAPTCHA that Exploits Interest-Aligned Manual Image Categorization},
    booktitle = {Proceedings of 14th ACM Conference on Computer and Communications Security (CCS)},
    year = {2007},
    month = {October},
    publisher = {Association for Computing Machinery, Inc.},
    url = {https://www.microsoft.com/en-us/research/publication/asirra-a-captcha-that-exploits-interest-aligned-manual-image-categorization/},
    edition = {Proceedings of 14th ACM Conference on Computer and Communications Security (CCS)},
    }""",
)
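label_info is a tfds.features.ClassLabel, which maps between integer class IDs and their string names; for cats_vs_dogs the two classes are 'cat' and 'dog':
# Map between integer labels and class names via the ClassLabel feature.
print(label_info.names)           # ['cat', 'dog']
print(label_info.int2str(1))      # 'dog'
print(label_info.str2int('cat'))  # 0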
def process_record(record):
  """Processes a single record into a tf.train.Example for TFRecords.

  This function takes a record, typically containing image and label data,
  and converts it into a TFRecord example. A detailed explanation is
  available at https://www.tensorflow.org/api_docs/python/tf/train/Example

  Args:
    record (dict): A dictionary containing the record data with the
      following keys:
      - 'image': A tensor representing the image data.
      - 'label': A tensor representing the label associated with the image.

  Returns:
    tf.train.Example: A TFRecord example containing the processed data with
      the following features:
      - 'image/encoded': The encoded image data as a feature.
      - 'image/class/label': The label data as a feature.
  """
  keys_to_features = {
      'image/encoded': tfrecord_lib.convert_to_feature(
          tf.io.encode_jpeg(record['image']).numpy()),
      'image/class/label': tfrecord_lib.convert_to_feature(
          record['label'].numpy())
  }
  example = tf.train.Example(
      features=tf.train.Features(feature=keys_to_features))
  return example
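As a quick sanity check (a minimal sketch, not part of the original pipeline), you can serialize one record and parse it back with the matching feature spec:
# Round-trip one example through serialization to verify the TFRecord schema.
sample = next(iter(train_ds))
serialized = process_record(sample).SerializeToString()
feature_spec = {
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/class/label': tf.io.FixedLenFeature([], tf.int64),
}
parsed = tf.io.parse_single_example(serialized, feature_spec)
print(tf.io.decode_jpeg(parsed['image/encoded']).shape,
      parsed['image/class/label'].numpy())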
def write_tfrecords(dataset, output_path, num_shards=1):
  """Writes a dataset to sharded TFRecord files.

  This function takes a dataset and writes it to one or more TFRecord files,
  splitting the data into shards if specified.

  Args:
    dataset (iterable): An iterable containing the data records to be written
      to TFRecords. Each record should be in a format suitable for processing
      with the 'process_record' function.
    output_path (str): The base path where the TFRecord files will be saved.
      If 'num_shards' is greater than 1, a unique suffix for each shard will
      be added to the base path.
    num_shards (int, optional): The number of TFRecord files to split the
      data into. Defaults to 1, indicating no sharding.

  Returns:
    None
  """
  writers = [
      tf.io.TFRecordWriter(
          output_path + '-%05d-of-%05d.tfrecord' % (i, num_shards))
      for i in range(num_shards)
  ]
  for idx, record in enumerate(dataset):
    if idx % LOG_EVERY == 0:
      print('On image %d' % idx)
    tf_example = process_record(record)
    writers[idx % num_shards].write(tf_example.SerializeToString())
  # Close the writers so all buffered records are flushed to disk.
  for writer in writers:
    writer.close()
LOG_EVERY = 1000
output_dir = './cat_vs_dogs_tfrecords/'
if not os.path.exists(output_dir):
  os.mkdir(output_dir)
output_train_tfrecs = output_dir + 'train'
write_tfrecords(train_ds, output_train_tfrecs,
                num_shards=int(train_ds.cardinality().numpy() * 0.1))
On image 0 On image 1000 On image 2000 On image 3000 On image 4000 On image 5000 On image 6000 On image 7000 On image 8000 On image 9000 On image 10000 On image 11000 On image 12000 On image 13000 On image 14000 On image 15000 On image 16000
output_validation_tfrecs = output_dir + 'validation'
write_tfrecords(validation_ds, output_validation_tfrecs,
                num_shards=int(validation_ds.cardinality().numpy() * 0.1))
On image 0 On image 1000 On image 2000 On image 3000 On image 4000
output_test_tfrecs = output_dir + 'test'
write_tfrecords(test_ds, output_test_tfrecs,
                num_shards=int(test_ds.cardinality().numpy() * 0.1))
On image 0 On image 1000 On image 2000
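As an optional check (not part of the training flow), you can glob the shard files back and count the serialized records:
# Count the records written to the training shards (sanity check).
train_files = tf.io.gfile.glob(output_train_tfrecs + '*')
num_records = tf.data.TFRecordDataset(train_files).reduce(
    0, lambda count, _: count + 1)
print(f'{len(train_files)} shards, {num_records.numpy()} records')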
exp_config = core.exp_factory.get_exp_config('darknet_classification')
BATCH_SIZE = 16
IMG_SIZE = 224
epochs = 10
steps_per_epoch = int(train_ds.cardinality().numpy() / BATCH_SIZE)
validation_steps = int(validation_ds.cardinality().numpy() / BATCH_SIZE)
num_steps = epochs * steps_per_epoch
lr = 0.012
warmup_lr = 0.1 * lr
exp_config.task.model.input_size = [IMG_SIZE, IMG_SIZE, 3]
exp_config.task.model.num_classes = ds_info.features['label'].num_classes
exp_config.task.train_data.input_path = f'{output_train_tfrecs}*'
exp_config.task.train_data.global_batch_size = BATCH_SIZE
exp_config.task.validation_data.input_path = f'{output_validation_tfrecs}*'
exp_config.task.validation_data.global_batch_size = BATCH_SIZE
exp_config.trainer.checkpoint_interval = steps_per_epoch
exp_config.trainer.best_checkpoint_export_subdir = 'best_ckpt'
exp_config.trainer.optimizer_config.optimizer.type = 'sgd'
exp_config.trainer.optimizer_config.optimizer.sgd.momentum = 0.9
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = num_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = lr
exp_config.trainer.optimizer_config.warmup.type = 'linear'
exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = warmup_lr
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = int(0.1 * steps_per_epoch)
exp_config.trainer.train_steps = num_steps
exp_config.trainer.steps_per_loop = steps_per_epoch
exp_config.trainer.validation_steps = validation_steps
exp_config.trainer.validation_interval = steps_per_epoch
exp_config.trainer.summary_interval = steps_per_epoch
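The learning-rate schedule configured above is a linear warmup followed by cosine decay. As an illustrative sketch (the actual schedule is built by Model Garden from exp_config; tf.keras.optimizers.schedules.CosineDecay only mirrors the decay portion), you can plot its shape before training:
# Plot the cosine-decay portion of the schedule (illustration only; the real
# optimizer and schedule are constructed by Model Garden from exp_config).
cosine = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=lr, decay_steps=num_steps)
steps = list(range(0, num_steps, steps_per_epoch))
plt.plot(steps, [cosine(step).numpy() for step in steps])
plt.xlabel('Step')
plt.ylabel('Learning rate')
plt.show()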
# Detect hardware
try:
  tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
except ValueError:
  tpu_resolver = None
  gpus = tf.config.experimental.list_logical_devices("GPU")

# Select an appropriate distribution strategy
if tpu_resolver:
  tf.config.experimental_connect_to_cluster(tpu_resolver)
  tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu_resolver)
  print('Running on TPU ', tpu_resolver.cluster_spec().as_dict()['worker'])
elif len(gpus) > 1:
  distribution_strategy = tf.distribute.MirroredStrategy([gpu.name for gpu in gpus])
  print('Running on multiple GPUs ', [gpu.name for gpu in gpus])
elif len(gpus) == 1:
  distribution_strategy = tf.distribute.get_strategy()  # Default strategy; works on CPU and single GPU.
  print('Running on single GPU ', gpus[0].name)
else:
  distribution_strategy = tf.distribute.get_strategy()  # Default strategy; works on CPU and single GPU.
  print('Running on CPU')

print("Number of accelerators: ", distribution_strategy.num_replicas_in_sync)
Running on single GPU /device:GPU:0 Number of accelerators: 1
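Note that global_batch_size in the data configs is the total batch size across all replicas; under MirroredStrategy each replica processes global_batch_size / num_replicas_in_sync examples per step:
# The global batch size is split evenly across replicas (informational check).
per_replica_batch_size = BATCH_SIZE // distribution_strategy.num_replicas_in_sync
print('Per-replica batch size:', per_replica_batch_size)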
pprint.pprint(exp_config.as_dict())
{'runtime': {'all_reduce_alg': None, 'batchnorm_spatial_persistent': False, 'dataset_num_private_threads': None, 'default_shard_dim': -1, 'distribution_strategy': 'mirrored', 'enable_xla': False, 'gpu_thread_mode': None, 'loss_scale': None, 'mixed_precision_dtype': None, 'num_cores_per_replica': 1, 'num_gpus': 0, 'num_packs': 1, 'per_gpu_thread_count': 0, 'run_eagerly': False, 'task_index': -1, 'tpu': None, 'tpu_enable_xla_dynamic_padder': None, 'use_tpu_mp_strategy': False, 'worker_hosts': None}, 'task': {'allow_image_summary': False, 'differential_privacy_config': None, 'evaluation': {'precision_and_recall_thresholds': None, 'report_per_class_precision_and_recall': False, 'top_k': 5}, 'freeze_backbone': False, 'gradient_clip_norm': 0.0, 'init_checkpoint': '', 'logging_dir': None, 'losses': {'l2_weight_decay': 0.0, 'label_smoothing': 0.0, 'loss_weight': 1.0, 'one_hot': True, 'soft_labels': False, 'use_binary_cross_entropy': False}, 'model': {'add_head_batch_norm': False, 'backbone': {'darknet': {'depth_scale': 1.0, 'dilate': False, 'max_level': 5, 'min_level': 3, 'model_id': 'cspdarknet53', 'use_reorg_input': False, 'use_separable_conv': False, 'width_scale': 1.0}, 'type': 'darknet'}, 'dropout_rate': 0.0, 'input_size': [224, 224, 3], 'kernel_initializer': 'VarianceScaling', 'norm_activation': {'activation': 'relu', 'norm_epsilon': 0.001, 'norm_momentum': 0.99, 'use_sync_bn': True}, 'num_classes': 2}, 'name': None, 'train_data': {'apply_tf_data_service_before_batching': False, 'aug_crop': True, 'aug_policy': None, 'aug_rand_hflip': True, 'aug_type': None, 'autotune_algorithm': None, 'block_length': 1, 'cache': False, 'center_crop_fraction': 0.875, 'color_jitter': 0.0, 'crop_area_range': (0.08, 1.0), 'cycle_length': 10, 'decode_jpeg_only': True, 'decoder': {'simple_decoder': {'attribute_names': [], 'mask_binarize_threshold': None, 'regenerate_source_id': False}, 'type': 'simple_decoder'}, 'deterministic': None, 'drop_remainder': True, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 'tfrecord', 'global_batch_size': 16, 'image_field_key': 'image/encoded', 'input_path': './cat_vs_dogs_tfrecords/train*', 'is_multilabel': False, 'is_training': True, 'label_field_key': 'image/class/label', 'mixup_and_cutmix': None, 'prefetch_buffer_size': None, 'randaug_magnitude': 10, 'random_erasing': None, 'repeated_augment': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 10000, 'tf_data_service_address': None, 'tf_data_service_job_name': None, 'tf_resize_method': 'bilinear', 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': '', 'tfds_skip_decoding_feature': '', 'tfds_split': '', 'three_augment': False, 'trainer_id': None, 'weights': None}, 'validation_data': {'apply_tf_data_service_before_batching': False, 'aug_crop': True, 'aug_policy': None, 'aug_rand_hflip': True, 'aug_type': None, 'autotune_algorithm': None, 'block_length': 1, 'cache': False, 'center_crop_fraction': 0.875, 'color_jitter': 0.0, 'crop_area_range': (0.08, 1.0), 'cycle_length': 10, 'decode_jpeg_only': True, 'decoder': {'simple_decoder': {'attribute_names': [], 'mask_binarize_threshold': None, 'regenerate_source_id': False}, 'type': 'simple_decoder'}, 'deterministic': None, 'drop_remainder': True, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 'tfrecord', 'global_batch_size': 16, 'image_field_key': 'image/encoded', 'input_path': 
'./cat_vs_dogs_tfrecords/validation*', 'is_multilabel': False, 'is_training': False, 'label_field_key': 'image/class/label', 'mixup_and_cutmix': None, 'prefetch_buffer_size': None, 'randaug_magnitude': 10, 'random_erasing': None, 'repeated_augment': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 10000, 'tf_data_service_address': None, 'tf_data_service_job_name': None, 'tf_resize_method': 'bilinear', 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': '', 'tfds_skip_decoding_feature': '', 'tfds_split': '', 'three_augment': False, 'trainer_id': None, 'weights': None}}, 'trainer': {'allow_tpu_summary': False, 'best_checkpoint_eval_metric': '', 'best_checkpoint_export_subdir': 'best_ckpt', 'best_checkpoint_metric_comp': 'higher', 'checkpoint_interval': 1017, 'continuous_eval_timeout': 3600, 'eval_tf_function': True, 'eval_tf_while_loop': False, 'loss_upper_bound': 1000000.0, 'max_to_keep': 5, 'optimizer_config': {'ema': None, 'learning_rate': {'cosine': {'alpha': 0.0, 'decay_steps': 10170, 'initial_learning_rate': 0.012, 'name': 'CosineDecay', 'offset': 0}, 'type': 'cosine'}, 'optimizer': {'sgd': {'clipnorm': None, 'clipvalue': None, 'decay': 0.0, 'global_clipnorm': None, 'momentum': 0.9, 'name': 'SGD', 'nesterov': False}, 'type': 'sgd'}, 'warmup': {'linear': {'name': 'linear', 'warmup_learning_rate': 0.0012000000000000001, 'warmup_steps': 101}, 'type': 'linear'}}, 'preemption_on_demand_checkpoint': True, 'recovery_begin_steps': 0, 'recovery_max_trials': 0, 'steps_per_loop': 1017, 'summary_interval': 1017, 'train_steps': 10170, 'train_tf_function': True, 'train_tf_while_loop': True, 'validation_interval': 1017, 'validation_steps': 290, 'validation_summary_subdir': 'validation'}}
Create the Task object (tfm.core.base_task.Task) from the config_definitions.TaskConfig. The Task object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by tfm.core.train_lib.run_experiment.
model_dir = './trained_model/'
with distribution_strategy.scope():
  task = core.task_factory.get_task(exp_config.task, logging_dir=model_dir)
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')
  print(f'labels.shape: {str(labels.shape):16} labels.dtype: {labels.dtype!r}')
images.shape: (16, 224, 224, 3) images.dtype: tf.float32 labels.shape: (16,) labels.dtype: tf.int32
core.train_utils.serialize_config(exp_config, model_dir)
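serialize_config writes the experiment configuration to params.yaml under model_dir (the same file the export step below passes as --config_file). A quick sanity check that it round-trips:
# Read back the serialized experiment configuration (sanity check).
import yaml

with open(os.path.join(model_dir, 'params.yaml')) as f:
  saved_params = yaml.safe_load(f)
print(saved_params['trainer']['train_steps'])  # Should match num_steps.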
Use ds_info (which is an instance of tfds.core.DatasetInfo) to look up the text descriptions of each class ID.
label_info = ds_info.features['label']
def show_batch(images, labels, predictions=None):
  """Plots a batch of images with their labels (and predictions, if given)."""
  plt.figure(figsize=(10, 10))
  # Rescale pixel values to [0, 1] for display.
  images_min = images.numpy().min()
  images_max = images.numpy().max()
  delta = images_max - images_min
  for i in range(BATCH_SIZE):
    plt.subplot(4, 4, i + 1)
    plt.imshow((images[i] - images_min) / delta)
    if predictions is None:
      plt.title(label_info.int2str(labels[i]))
    else:
      # Green title for a correct prediction, red for an incorrect one.
      color = 'g' if labels[i] == predictions[i] else 'r'
      plt.title(label_info.int2str(predictions[i]), color=color)
    plt.axis("off")
  plt.show()
for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):
  show_batch(images, labels)
model, eval_logs = core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)
restoring or initializing model... train | step: 0 | training until step 1017... train | step: 1017 | steps/sec: 6.4 | output: {'accuracy': 0.54013026, 'learning_rate': 0.011706339, 'top_5_accuracy': 1.0, 'training_loss': 0.7053006} saved checkpoint to ./trained_model/ckpt-1017. eval | step: 1017 | running 290 steps of evaluation... eval | step: 1017 | steps/sec: 17.6 | eval time: 16.4 sec | output: {'accuracy': 0.57693964, 'steps_per_second': 17.634644012635317, 'top_5_accuracy': 1.0, 'validation_loss': 0.67286325} train | step: 1017 | training until step 2034... train | step: 2034 | steps/sec: 8.6 | output: {'accuracy': 0.58898723, 'learning_rate': 0.010854102, 'top_5_accuracy': 1.0, 'training_loss': 0.6713662} saved checkpoint to ./trained_model/ckpt-2034. eval | step: 2034 | running 290 steps of evaluation... eval | step: 2034 | steps/sec: 37.3 | eval time: 7.8 sec | output: {'accuracy': 0.6153017, 'steps_per_second': 37.2743481981908, 'top_5_accuracy': 1.0, 'validation_loss': 0.652928} train | step: 2034 | training until step 3051... train | step: 3051 | steps/sec: 9.3 | output: {'accuracy': 0.6157817, 'learning_rate': 0.009526712, 'top_5_accuracy': 1.0, 'training_loss': 0.65529084} saved checkpoint to ./trained_model/ckpt-3051. eval | step: 3051 | running 290 steps of evaluation... eval | step: 3051 | steps/sec: 37.2 | eval time: 7.8 sec | output: {'accuracy': 0.51681036, 'steps_per_second': 37.18629141886086, 'top_5_accuracy': 1.0, 'validation_loss': 1.0405935} train | step: 3051 | training until step 4068... train | step: 4068 | steps/sec: 9.3 | output: {'accuracy': 0.6411013, 'learning_rate': 0.007854102, 'top_5_accuracy': 1.0, 'training_loss': 0.63446975} saved checkpoint to ./trained_model/ckpt-4068. eval | step: 4068 | running 290 steps of evaluation... eval | step: 4068 | steps/sec: 37.3 | eval time: 7.8 sec | output: {'accuracy': 0.68297416, 'steps_per_second': 37.306989215139936, 'top_5_accuracy': 1.0, 'validation_loss': 0.59876406} train | step: 4068 | training until step 5085... train | step: 5085 | steps/sec: 9.3 | output: {'accuracy': 0.6549287, 'learning_rate': 0.0059999996, 'top_5_accuracy': 1.0, 'training_loss': 0.6212429} saved checkpoint to ./trained_model/ckpt-5085. eval | step: 5085 | running 290 steps of evaluation... eval | step: 5085 | steps/sec: 37.7 | eval time: 7.7 sec | output: {'accuracy': 0.67349136, 'steps_per_second': 37.65721170518001, 'top_5_accuracy': 1.0, 'validation_loss': 0.59417003} train | step: 5085 | training until step 6102... train | step: 6102 | steps/sec: 9.3 | output: {'accuracy': 0.6778515, 'learning_rate': 0.004145897, 'top_5_accuracy': 1.0, 'training_loss': 0.5970689} saved checkpoint to ./trained_model/ckpt-6102. eval | step: 6102 | running 290 steps of evaluation... eval | step: 6102 | steps/sec: 37.3 | eval time: 7.8 sec | output: {'accuracy': 0.68168104, 'steps_per_second': 37.324507959229, 'top_5_accuracy': 1.0, 'validation_loss': 0.6371893} train | step: 6102 | training until step 7119... train | step: 7119 | steps/sec: 9.3 | output: {'accuracy': 0.7034169, 'learning_rate': 0.0024732884, 'top_5_accuracy': 1.0, 'training_loss': 0.56852204} saved checkpoint to ./trained_model/ckpt-7119. eval | step: 7119 | running 290 steps of evaluation... eval | step: 7119 | steps/sec: 37.4 | eval time: 7.7 sec | output: {'accuracy': 0.7521552, 'steps_per_second': 37.433325893948016, 'top_5_accuracy': 1.0, 'validation_loss': 0.52360195} train | step: 7119 | training until step 8136... 
train | step: 8136 | steps/sec: 9.3 | output: {'accuracy': 0.7234513, 'learning_rate': 0.0011458977, 'top_5_accuracy': 1.0, 'training_loss': 0.54219157} saved checkpoint to ./trained_model/ckpt-8136. eval | step: 8136 | running 290 steps of evaluation... eval | step: 8136 | steps/sec: 37.5 | eval time: 7.7 sec | output: {'accuracy': 0.787931, 'steps_per_second': 37.471471146938974, 'top_5_accuracy': 1.0, 'validation_loss': 0.45562845} train | step: 8136 | training until step 9153... train | step: 9153 | steps/sec: 9.3 | output: {'accuracy': 0.7457596, 'learning_rate': 0.0002936611, 'top_5_accuracy': 1.0, 'training_loss': 0.50991994} saved checkpoint to ./trained_model/ckpt-9153. eval | step: 9153 | running 290 steps of evaluation... eval | step: 9153 | steps/sec: 37.7 | eval time: 7.7 sec | output: {'accuracy': 0.80625, 'steps_per_second': 37.66015917575601, 'top_5_accuracy': 1.0, 'validation_loss': 0.41287535} train | step: 9153 | training until step 10170... train | step: 10170 | steps/sec: 9.3 | output: {'accuracy': 0.7517822, 'learning_rate': 0.0, 'top_5_accuracy': 1.0, 'training_loss': 0.5036062} saved checkpoint to ./trained_model/ckpt-10170. eval | step: 10170 | running 290 steps of evaluation... eval | step: 10170 | steps/sec: 37.6 | eval time: 7.7 sec | output: {'accuracy': 0.81314653, 'steps_per_second': 37.58710303226575, 'top_5_accuracy': 1.0, 'validation_loss': 0.40270108} eval | step: 10170 | running 290 steps of evaluation... eval | step: 10170 | steps/sec: 37.3 | eval time: 7.8 sec | output: {'accuracy': 0.81314653, 'steps_per_second': 37.3229583933457, 'top_5_accuracy': 1.0, 'validation_loss': 0.40270108}
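Because run_experiment was called with run_post_eval=True, it also returns the final evaluation metrics, which can be inspected directly (values may be Python floats or scalar tensors):
# Print the final post-training evaluation metrics returned by run_experiment.
for name, value in eval_logs.items():
  print(f'{name}: {float(value):.4f}')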
EXPORT_DIR_PATH = "./exported_model/"
!python -m official.projects.yolo.serving.export_saved_model \
  --experiment="darknet_classification" \
  --export_dir=$EXPORT_DIR_PATH/ \
  --checkpoint_path=$model_dir \
  --config_file=$model_dir/params.yaml \
  --batch_size=$BATCH_SIZE \
  --input_type="image_tensor" \
  --input_image_size=$IMG_SIZE,$IMG_SIZE
2023-11-14 21:36:10.722529: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT 2023-11-14 21:36:18.463386: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0. I1114 21:36:49.038166 140374729912320 signature_serialization.py:148] Function `serve` contains input name(s) resource with unsupported characters which will be renamed to dense_biasadd_readvariableop_resource in the SavedModel. I1114 21:37:34.619264 140374729912320 save.py:274] Found untraced functions such as serve_eval, conv2d_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, _jit_compiled_convolution_op, conv_bn_layer_call_fn while saving (showing 5 of 415). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: ./exported_model//saved_model/assets I1114 21:37:46.884536 140374729912320 builder_impl.py:804] Assets written to: ./exported_model//saved_model/assets I1114 21:37:47.379255 140374729912320 fingerprinting_utils.py:48] Writing fingerprint to ./exported_model//saved_model/fingerprint.pb I1114 21:37:48.188559 140374729912320 train_utils.py:400] Saving experiment configuration to ./exported_model//params.yaml
imported = tf.saved_model.load('/content/exported_model/saved_model')
model_fn = imported.signatures['serving_default']
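Before running inference, it can help to inspect what the serving signature expects and returns (these attributes exist on any restored concrete function):
# Inspect the expected inputs and the output structure of the signature.
print(model_fn.structured_input_signature)
print(model_fn.structured_outputs)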
def resize_image(record):
  # Resize to the model's input size and restore the uint8 image dtype.
  image = tf.image.resize(record['image'], size=(IMG_SIZE, IMG_SIZE))
  image = tf.cast(image, tf.uint8)
  return image, record['label']

test_ds_resized = test_ds.map(resize_image).shuffle(100)
test_ds_batched = test_ds_resized.batch(BATCH_SIZE)
for images, labels in test_ds_batched.take(1):
  predictions = model_fn(inputs=images)['logits']
  predictions = tf.argmax(predictions, axis=-1)
  show_batch(images, labels, predictions)
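The batch above is only a visual spot check. As a rough sketch (assuming the exported model's fixed batch size, hence drop_remainder=True), accuracy over the full test split can be computed like this:
# Measure accuracy over the whole test split with the exported model.
# drop_remainder=True keeps every batch at BATCH_SIZE, matching the fixed
# batch size the model was exported with.
correct, total = 0, 0
for images, labels in test_ds_resized.batch(BATCH_SIZE, drop_remainder=True):
  logits = model_fn(inputs=images)['logits']
  preds = tf.argmax(logits, axis=-1)
  correct += int(tf.reduce_sum(tf.cast(preds == labels, tf.int32)))
  total += BATCH_SIZE
print('Test accuracy:', correct / total)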