!pip install -q tflite-model-maker
import os
import numpy as np
import tensorflow as tf
assert tf.__version__.startswith('2')
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
import matplotlib.pyplot as plt
import tflite_model_maker
# Root folder of the image dataset. The loader's log output reports 563 images
# and two labels: "fall" and "not-fall".
image_path = 'fall_dataset/'
# Build the labelled dataset from that folder.
# NOTE(review): labels presumably come from the sub-folder names — confirm
# against the Model Maker DataLoader documentation.
data = DataLoader.from_folder(image_path)
INFO:tensorflow:Load image with size: 563, num_label: 2, labels: fall, not-fall.
# Split the loaded data in two with a 0.8 fraction; the first part is used for
# training and the second for validation. No separate test split is created.
train_data, validation_data = data.split(0.8)
# Retrain an image classifier from the 'efficientnet_lite0' model spec for
# 5 epochs, monitoring loss/accuracy on validation_data after each epoch.
model = image_classifier.create(train_data, model_spec='efficientnet_lite0', validation_data=validation_data, epochs=5)
# Prints the ImageClassifier object's default repr, not a layer summary.
print(model)
INFO:tensorflow:Retraining the models... Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= hub_keras_layer_v1v2 (HubKer (None, 1280) 3413024 _________________________________________________________________ dropout (Dropout) (None, 1280) 0 _________________________________________________________________ dense (Dense) (None, 2) 2562 ================================================================= Total params: 3,415,586 Trainable params: 2,562 Non-trainable params: 3,413,024 _________________________________________________________________ None Epoch 1/5
/home/bhavika/.local/lib/python3.8/site-packages/keras/optimizer_v2/optimizer_v2.py:355: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead. warnings.warn(
14/14 [==============================] - 11s 705ms/step - loss: 0.6039 - accuracy: 0.6942 - val_loss: 0.4400 - val_accuracy: 0.8958 Epoch 2/5 14/14 [==============================] - 9s 664ms/step - loss: 0.4081 - accuracy: 0.8795 - val_loss: 0.3727 - val_accuracy: 0.9167 Epoch 3/5 14/14 [==============================] - 10s 701ms/step - loss: 0.3128 - accuracy: 0.9487 - val_loss: 0.3448 - val_accuracy: 0.9375 Epoch 4/5 14/14 [==============================] - 10s 754ms/step - loss: 0.3099 - accuracy: 0.9442 - val_loss: 0.3324 - val_accuracy: 0.9375 Epoch 5/5 14/14 [==============================] - 10s 742ms/step - loss: 0.2923 - accuracy: 0.9554 - val_loss: 0.3308 - val_accuracy: 0.9271 <tensorflow_examples.lite.model_maker.core.task.image_classifier.ImageClassifier object at 0x7efd32382640>
# Export the trained model as a TFLite file under the 'ai_models' directory.
# The export log reports that the label file is stored inside the model's
# metadata.
model.export(export_dir='ai_models', tflite_filename='tflite-model-maker-falldetect-model-2oct.tflite')
INFO:tensorflow:Assets written to: /tmp/tmp_4e0792j/assets
INFO:tensorflow:Assets written to: /tmp/tmp_4e0792j/assets WARNING:absl:For model inputs containing unsupported operations which cannot be quantized, the `inference_input_type` attribute will default to the original type.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Saving labels in /tmp/tmp4rmkh08_/labels.txt
INFO:tensorflow:Saving labels in /tmp/tmp4rmkh08_/labels.txt
INFO:tensorflow:TensorFlow Lite model exported successfully: ai_models/tflite-model-maker-falldetect-model-2oct.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ai_models/tflite-model-maker-falldetect-model-2oct.tflite
# Evaluate the exported TFLite file (not the in-memory Keras model).
# NOTE(review): validation_data was also passed to image_classifier.create()
# as the validation set, so this is not a measurement on held-out test data.
accuracy = model.evaluate_tflite('ai_models/tflite-model-maker-falldetect-model-2oct.tflite', validation_data)
# The result is printed as a dict with an 'accuracy' key.
print(accuracy)
{'accuracy': 0.9292035398230089}
def get_label_color(val1, val2):
    """Return the color name used to mark whether two labels agree.

    Args:
        val1: First value to compare (callers pass the predicted label).
        val2: Second value to compare (callers pass the expected label).

    Returns:
        'black' when the two values compare equal, otherwise 'red'.
    """
    return 'black' if val1 == val2 else 'red'
# Draw up to 100 validation images in a 10x10 grid, each captioned with the
# model's top prediction. The caption is black when the prediction matches the
# image's label and red when it does not.
plt.figure(figsize=(20, 30))
predicts = model.predict_top_k(validation_data)
for i, (image, label) in enumerate(validation_data.gen_dataset().unbatch().take(100)):
    # Subplot positions are 1-based, hence i + 1.
    ax = plt.subplot(10, 10, i + 1)
    # Hide ticks and grid so only the image and its caption are visible.
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image.numpy(), cmap=plt.cm.gray)

    # First field of the first (top-ranked) entry for image i.
    # NOTE(review): this is compared against a label name below, so it is
    # presumably the predicted label string — confirm against predict_top_k.
    # This also assumes predicts is in the same order as gen_dataset() yields.
    predict_label = predicts[i][0][0]
    color = get_label_color(predict_label,
                            validation_data.index_to_label[label.numpy()])
    ax.xaxis.label.set_color(color)
    plt.xlabel('Predicted: %s' % predict_label)
# Render the whole grid once, after every subplot has been populated.
plt.show()
print('Accuracy->', accuracy)
Accuracy-> {'accuracy': 0.9292035398230089}