#!/usr/bin/env python
# coding: utf-8

# # Post-Training Quantization with TensorFlow Classification Model
#
# This example demonstrates how to quantize the OpenVINO model that was created in [301-tensorflow-training-openvino.ipynb](301-tensorflow-training-openvino.ipynb), to improve inference speed. Quantization is performed with the [Post-Training Optimization Tool (POT)](https://docs.openvino.ai/nightly/pot_README.html). A custom dataloader and metric will be defined, and accuracy and performance will be computed for the original IR model and the quantized model.

# ## Preparation
#
# The notebook requires that the training notebook has been run and that the Intermediate Representation (IR) models have been created. If the IR models do not exist, running the next cell will run the training notebook. This will take a while.

# In[ ]:


from pathlib import Path

import tensorflow as tf

model_xml = Path("model/flower/flower_ir.xml")
dataset_url = (
    "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
)
data_dir = Path(tf.keras.utils.get_file("flower_photos", origin=dataset_url, untar=True))

if not model_xml.exists():
    print("Executing training notebook. This will take a while...")
    get_ipython().run_line_magic('run', '301-tensorflow-training-openvino.ipynb')


# ### Imports
#
# The Post-Training Optimization API is implemented in the `openvino.tools.pot` package.

# In[ ]:


import copy
import os
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
from addict import Dict
from openvino.tools.pot.api import Metric, DataLoader
from openvino.tools.pot.graph import load_model, save_model
from openvino.tools.pot.graph.model_utils import compress_model_weights
from openvino.tools.pot.engines.ie_engine import IEEngine
from openvino.tools.pot.pipeline.initializer import create_pipeline
from openvino.runtime import Core
from PIL import Image

sys.path.append("../utils")
from notebook_utils import download_file


# ### Settings
#
# In the next cell, the settings for running quantization are defined. The configuration below uses the _performance_ preset and the _DefaultQuantization_ algorithm. This enables reasonably fast quantization, with a possible drop in accuracy. The _mixed_ preset, which quantizes activations asymmetrically, is an alternative that can help preserve accuracy. The _AccuracyAwareQuantization_ algorithm quantizes the model to a defined maximal accuracy drop; it may not achieve the greatest performance boost, but it avoids a further drop in accuracy.
#
# See the [Post-Training Optimization Best Practices](https://docs.openvino.ai/latest/pot_docs_BestPractices.html) page for more information about the configurable parameters and best practices for post-training quantization.
#
# The POT methods expect configuration dictionaries as arguments. They are defined in the cell below.

# In[ ]:


model_config = Dict(
    {
        "model_name": "flower",
        "model": "model/flower/flower_ir.xml",
        "weights": "model/flower/flower_ir.bin",
    }
)

engine_config = Dict({"device": "CPU", "stat_requests_number": 2, "eval_requests_number": 2})

algorithms = [
    {
        "name": "DefaultQuantization",
        "params": {
            "target_device": "CPU",
            "preset": "performance",
            "stat_subset_size": 1000,
        },
    }
]


# ### Create DataLoader Class
#
# The POT API provides a `DataLoader` class that defines how to load data and annotations. For the TensorFlow flowers dataset, images are stored in a directory per category.
# The DataLoader loads images from a given _data_source_ directory and assigns a label based on the position of the directory in _class_names_, where _class_names_ is a list of directory names in alphabetical order.

# In[ ]:


class ClassificationDataLoader(DataLoader):
    """
    DataLoader for image data that is stored in a directory per category. For example, for
    categories _rose_ and _daisy_, rose images are expected in data_source/rose and daisy
    images in data_source/daisy.
    """

    def __init__(self, data_source):
        """
        :param data_source: path to data directory
        """
        self.data_source = Path(data_source)
        self.dataset = [p for p in self.data_source.glob("**/*") if p.suffix in (".png", ".jpg")]
        self.class_names = sorted(
            [item.name for item in self.data_source.iterdir() if item.is_dir()]
        )

    def __len__(self):
        """
        Returns the number of elements in the dataset
        """
        return len(self.dataset)

    def __getitem__(self, index):
        """
        Get item from self.dataset at the specified index.
        Returns (annotation, image), where annotation is a tuple (index, class_index)
        and image is a preprocessed image in network shape
        """
        if index >= len(self):
            raise IndexError
        filepath = self.dataset[index]
        annotation = (index, self.class_names.index(filepath.parent.name))
        image = self._read_image(filepath)
        return annotation, image

    def _read_image(self, filepath):
        """
        Read the image at `filepath` into memory, resize it to network shape and convert it
        from BGR (OpenCV default) to RGB

        :param filepath: path to the image file
        :return: ndarray representation of the image
        """
        image = cv2.imread(str(filepath))[:, :, (2, 1, 0)]
        image = cv2.resize(image, (180, 180)).astype(np.float32)
        return image


# ### Create Accuracy Metric Class
#
# The accuracy metric is defined as the number of correct predictions divided by the total number of predictions. It is used to validate the accuracy of the quantized model.
#
# The Accuracy class in this tutorial implements the `Metric` interface of the POT API.

# In[ ]:


class Accuracy(Metric):
    def __init__(self):
        super().__init__()
        self._name = "accuracy"
        self._matches = []

    @property
    def value(self):
        """Returns accuracy metric value for the last model output."""
        return {self._name: self._matches[-1]}

    @property
    def avg_value(self):
        """
        Returns accuracy metric value for all model outputs. Results per image are stored in
        self._matches, where True means a correct prediction and False a wrong prediction.
        Accuracy is computed as the number of correct predictions divided by the total number
        of predictions.
        """
        num_correct = np.count_nonzero(self._matches)
        return {self._name: num_correct / len(self._matches)}

    def update(self, output, target):
        """
        Updates prediction matches.

        :param output: model output
        :param target: annotations
        """
        predict = np.argmax(output[0], axis=1)
        match = predict == target
        self._matches.append(match)

    def reset(self):
        """
        Resets the Accuracy metric. This is a required method that should initialize all
        attributes to their initial value.
        """
        self._matches = []

    def get_attributes(self):
        """
        Returns a dictionary of metric attributes {metric_name: {attribute_name: value}}.
        Required attributes: 'direction': 'higher-better' or 'higher-worse'
                             'type': metric type
        """
        return {self._name: {"direction": "higher-better", "type": "accuracy"}}


# ## POT Optimization
#
# After creating the DataLoader and Metric classes, and defining the configuration settings for POT, we can start the quantization process.
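# Before running the pipeline, the optional cell below performs a minimal sanity check of the `ClassificationDataLoader` and `Accuracy` classes defined above. It is an illustrative sketch only: the one-hot `fake_output` array stands in for a real model output, and the names `check_loader`, `check_metric` and `fake_output` are not part of the POT API.

# In[ ]:


# Optional sanity check: load one sample and simulate a single correct prediction
check_loader = ClassificationDataLoader(data_source=data_dir)
(sample_index, class_index), sample_image = check_loader[0]
print(f"Number of images: {len(check_loader)}")
print(f"Class names: {check_loader.class_names}")
print(f"Sample image shape: {sample_image.shape}, label: {check_loader.class_names[class_index]}")

check_metric = Accuracy()
# Illustrative stand-in for a model output: a one-hot vector for the correct class
fake_output = np.zeros((1, len(check_loader.class_names)), dtype=np.float32)
fake_output[0, class_index] = 1.0
check_metric.update(output=[fake_output], target=[class_index])
print(f"Accuracy after one simulated prediction: {check_metric.avg_value['accuracy']:.1f}")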
# In[ ]:


# Step 1: Load the model
model = load_model(model_config=model_config)
original_model = copy.deepcopy(model)

# Step 2: Initialize the data loader
data_loader = ClassificationDataLoader(data_source=data_dir)

# Step 3 (Optional. Required for AccuracyAwareQuantization): Initialize the metric
#         and compute metric results on the original model
metric = Accuracy()

# Step 4: Initialize the engine for metric calculation and statistics collection
engine = IEEngine(config=engine_config, data_loader=data_loader, metric=metric)

# Step 5: Create a pipeline of compression algorithms
pipeline = create_pipeline(algo_config=algorithms, engine=engine)

# Step 6: Execute the pipeline
compressed_model = pipeline.run(model=model)

# Step 7 (Optional): Compress the model weights to quantized precision
#                    in order to reduce the size of the final .bin file
compress_model_weights(model=compressed_model)

# Step 8: Save the compressed model and get the path to the model
compressed_model_paths = save_model(
    model=compressed_model, save_path=os.path.join(os.path.curdir, "model/optimized")
)
compressed_model_xml = Path(compressed_model_paths[0]["model"])
print(f"The quantized model is stored in {compressed_model_xml}")


# In[ ]:


# Step 9 (Optional): Evaluate the original and compressed model and print the results
original_metric_results = pipeline.evaluate(original_model)
if original_metric_results:
    print(f"Accuracy of the original model:  {next(iter(original_metric_results.values())):.5f}")

quantized_metric_results = pipeline.evaluate(compressed_model)
if quantized_metric_results:
    print(f"Accuracy of the quantized model: {next(iter(quantized_metric_results.values())):.5f}")


# ## Run Inference on Quantized Model
#
# Copy the preprocessing function from the training notebook and run inference on the quantized model with Inference Engine. See the [OpenVINO API tutorial](../002-openvino-api/002-openvino-api.ipynb) for more information about running inference with the Inference Engine Python API.

# In[ ]:


def pre_process_image(imagePath, img_height=180):
    # Model input format
    n, c, h, w = [1, 3, img_height, img_height]
    image = Image.open(imagePath)
    image = image.resize((h, w), resample=Image.BILINEAR)

    # Convert to array and add a batch dimension; the network expects NHWC layout
    image = np.array(image)
    input_image = image.reshape((n, h, w, c))

    return input_image


# In[ ]:


# Load the optimized model and get the names of the input and output layer
ie = Core()
model_pot = ie.read_model(model="model/optimized/flower_ir.xml")
compiled_model_pot = ie.compile_model(model=model_pot, device_name="CPU")
input_layer = compiled_model_pot.input(0)
output_layer = compiled_model_pot.output(0)

# Get the class names: a list of directory names in alphabetical order
class_names = sorted([item.name for item in Path(data_dir).iterdir() if item.is_dir()])

# Run inference on an input image
inp_img_url = (
    "https://upload.wikimedia.org/wikipedia/commons/4/48/A_Close_Up_Photo_of_a_Dandelion.jpg"
)
directory = "output"
inp_file_name = "A_Close_Up_Photo_of_a_Dandelion.jpg"
file_path = Path(directory) / Path(inp_file_name)

# Download the image if it does not exist yet
if not file_path.exists():
    download_file(inp_img_url, inp_file_name, directory=directory)

# Pre-process the image and get it ready for inference
input_image = pre_process_image(imagePath=file_path)
print(f"input image shape: {input_image.shape}")
print(f"input layer shape: {input_layer.shape}")

res = compiled_model_pot([input_image])[output_layer]

score = tf.nn.softmax(res[0])

# Show the results
image = Image.open(file_path)
plt.imshow(image)
print(
    "This image most likely belongs to {} with a {:.2f} percent confidence.".format(
        class_names[np.argmax(score)], 100 * np.max(score)
    )
)


# ## Compare Inference Speed
#
# Measure inference speed with the [OpenVINO Benchmark App](https://docs.openvino.ai/latest/openvino_inference_engine_tools_benchmark_tool_README.html).
#
# Benchmark App is a command line tool that measures raw inference performance for a specified OpenVINO IR model. Run `benchmark_app --help` to see a list of available parameters. By default, Benchmark App tests the performance of the model specified with the `-m` parameter, with asynchronous inference on CPU, for one minute. Use the `-d` parameter to test performance on a different device, for example an Intel integrated GPU (iGPU), and `-t` to set the number of seconds to run inference. See the [documentation](https://docs.openvino.ai/latest/openvino_inference_engine_tools_benchmark_tool_README.html) for more information.
#
# In this tutorial, `benchmark_app` is run directly from the notebook; in the final cells, its logging output is filtered so that only the performance results are shown.
#
# In the next cells, inference speed will be measured for the original and quantized model on CPU. If an iGPU is available, inference speed will be measured for CPU+GPU as well. The number of seconds to run inference is set to 15.
#
# > NOTE: For the most accurate performance estimation, it is recommended to run `benchmark_app` in a terminal/command prompt after closing other applications.

# In[ ]:


# print the available devices on this system
ie = Core()
print("Device information:")
print(ie.get_property("CPU", "FULL_DEVICE_NAME"))
if "GPU" in ie.available_devices:
    print(ie.get_property("GPU", "FULL_DEVICE_NAME"))


# In[ ]:


# Original model - CPU
get_ipython().system(' benchmark_app -m $model_xml -d CPU -t 15 -api async')


# In[ ]:


# Quantized model - CPU
get_ipython().system(' benchmark_app -m $compressed_model_xml -d CPU -t 15 -api async')


# **Benchmark on MULTI:CPU,GPU**
#
# With a recent Intel CPU, the best performance can often be achieved by doing inference on both the CPU and the iGPU, with OpenVINO's [Multi Device Plugin](https://docs.openvino.ai/2021.4/openvino_docs_IE_DG_supported_plugins_MULTI.html). It takes a bit longer to load a model on GPU than on CPU, so this benchmark will take a bit longer to complete than the CPU benchmark when it is run for the first time. Benchmark App supports model caching through the `-cdir` parameter; in the benchmark cells below, the model is cached to the `model_cache` directory.
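# Before running `benchmark_app` on MULTI, the optional cell below is a small sketch showing that the Multi Device Plugin can also be used directly from the OpenVINO Runtime Python API. It reuses `model_pot` and `input_image` from the inference section above; the name `compiled_model_multi` is only used for this illustration.

# In[ ]:


# Optional: compile the quantized model on MULTI:CPU,GPU with the Python API,
# if a supported integrated GPU is available, and run one inference on it
if "GPU" in ie.available_devices:
    compiled_model_multi = ie.compile_model(model=model_pot, device_name="MULTI:CPU,GPU")
    result_multi = compiled_model_multi([input_image])[compiled_model_multi.output(0)]
    print(f"MULTI:CPU,GPU prediction: {class_names[np.argmax(result_multi)]}")
else:
    print("A supported integrated GPU is not available on this system.")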
# In[ ]:


# Original model - MULTI:CPU,GPU
if "GPU" in ie.available_devices:
    get_ipython().system(' benchmark_app -m $model_xml -d MULTI:CPU,GPU -t 15 -api async -cdir "model_cache"')
else:
    print("A supported integrated GPU is not available on this system.")


# In[ ]:


# Quantized model - MULTI:CPU,GPU
if "GPU" in ie.available_devices:
    get_ipython().system(' benchmark_app -m $compressed_model_xml -d MULTI:CPU,GPU -t 15 -api async -cdir "model_cache"')
else:
    print("A supported integrated GPU is not available on this system.")


# In[ ]:


# print the available devices on this system
print("Device information:")
print(ie.get_property("CPU", "FULL_DEVICE_NAME"))
if "GPU" in ie.available_devices:
    print(ie.get_property("GPU", "FULL_DEVICE_NAME"))


# **Original IR model - CPU**

# In[ ]:


benchmark_output = get_ipython().run_line_magic('sx', 'benchmark_app -m $model_xml -t 15 -api async')
# Remove logging info from benchmark_app output and show only the results
benchmark_result = [
    line
    for line in benchmark_output
    if not (line.startswith(r"[") or line.startswith(" ") or line == "")
]
print("\n".join(benchmark_result))


# **Quantized IR model - CPU**

# In[ ]:


benchmark_output = get_ipython().run_line_magic('sx', 'benchmark_app -m $compressed_model_xml -t 15 -api async')
# Remove logging info from benchmark_app output and show only the results
benchmark_result = [
    line
    for line in benchmark_output
    if not (line.startswith(r"[") or line.startswith(" ") or line == "")
]
print("\n".join(benchmark_result))


# **Original IR model - MULTI:CPU,GPU**
#
# With a recent Intel CPU, the best performance can often be achieved by doing inference on both the CPU and the iGPU, with OpenVINO's [Multi Device Plugin](https://docs.openvino.ai/latest/openvino_docs_OV_UG_Running_on_multiple_devices.html). It takes a bit longer to load a model on GPU than on CPU, so this benchmark will take a bit longer to complete than the CPU benchmark.

# In[ ]:


ie = Core()
if "GPU" in ie.available_devices:
    benchmark_output = get_ipython().run_line_magic('sx', 'benchmark_app -m $model_xml -d MULTI:CPU,GPU -t 15 -api async')
    # Remove logging info from benchmark_app output and show only the results
    benchmark_result = [
        line
        for line in benchmark_output
        if not (line.startswith(r"[") or line.startswith(" ") or line == "")
    ]
    print("\n".join(benchmark_result))
else:
    print("An integrated GPU is not available on this system.")


# **Quantized IR model - MULTI:CPU,GPU**

# In[ ]:


ie = Core()
if "GPU" in ie.available_devices:
    benchmark_output = get_ipython().run_line_magic('sx', 'benchmark_app -m $compressed_model_xml -d MULTI:CPU,GPU -t 15 -api async')
    # Remove logging info from benchmark_app output and show only the results
    benchmark_result = [
        line
        for line in benchmark_output
        if not (line.startswith(r"[") or line.startswith(" ") or line == "")
    ]
    print("\n".join(benchmark_result))
else:
    print("An integrated GPU is not available on this system.")
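# As an optional final check, the cell below compares the on-disk size of the original and quantized model weights. This is an illustrative sketch rather than part of the tutorial flow above; it assumes that both `.bin` files exist next to the `.xml` files used earlier, and it shows the size reduction achieved by quantization and `compress_model_weights`.

# In[ ]:


# Compare the size of the original and quantized weights (.bin files)
original_bin = model_xml.with_suffix(".bin")
quantized_bin = compressed_model_xml.with_suffix(".bin")

original_size = original_bin.stat().st_size / 1024
quantized_size = quantized_bin.stat().st_size / 1024
print(f"Original model weights:  {original_size:.2f} KB")
print(f"Quantized model weights: {quantized_size:.2f} KB")
print(f"Size reduction: {original_size / quantized_size:.2f}x")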