Real-time object detection is often used as a key component in computer vision systems. Applications that use real-time object detection models include video analytics, robotics, autonomous vehicles, multi-object tracking and object counting, medical image analysis, and many others.
This tutorial provides step-by-step instructions on how to run and optimize PyTorch YOLOv12 with OpenVINO. We consider the steps required for an object detection scenario. You can find more details about the model on its model page in the Ultralytics documentation.
The tutorial consists of the following steps:

- Prepare the PyTorch model
- Convert the model to OpenVINO IR
- Run inference with OpenVINO
- Optimize the model with NNCF post-training quantization
- Compare performance of the FP32 and INT8 models
- Run a live object detection demo
This is a self-contained example that relies solely on its own code.
We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to the Installation Guide.
Generally, PyTorch models are represented as an instance of the torch.nn.Module class, initialized by a state dictionary with model weights.
We will use the YOLOv12 nano model (also known as yolo12n
) pre-trained on the COCO dataset, which is available in this repo. Similar steps are also applicable to other YOLOv12 models.
Typical steps to obtain a pre-trained model:

1. Create an instance of the model class.
2. Load a checkpoint state dict, which contains the pre-trained model weights.
3. Turn the model to evaluation mode, to switch some operations to inference behavior.

In this case, Ultralytics provides an API that enables converting the YOLOv12 model to OpenVINO IR, so we do not need to do these steps manually. A minimal sketch of the manual path is shown below for reference.
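For illustration, here is a minimal sketch of those manual steps for a generic PyTorch model. TinyNet and the checkpoint file are hypothetical placeholders, not artifacts used in this tutorial:

import torch
import torch.nn as nn


# A toy stand-in model; a real checkpoint must pair with a matching model class.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


model = TinyNet()  # 1. create a model class instance
torch.save(model.state_dict(), "tiny_checkpoint.pt")  # stand-in for a downloaded checkpoint
state_dict = torch.load("tiny_checkpoint.pt", map_location="cpu")
model.load_state_dict(state_dict)  # 2. load the pre-trained weights
model.eval()  # 3. switch the model to evaluation mode for inference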
import platform
if platform.system() == "Darwin":
    %pip install -q "numpy<2.0.0"
%pip install -qU "openvino>=2025.1.0" "nncf>=2.16.0"
%pip install -q "torch>=2.1" "torchvision>=0.16" tqdm opencv-python --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q "ultralytics==8.3.142"
Note: you may need to restart the kernel to use updated packages.
Import required utility functions.
The cell below downloads the notebook_utils Python module from GitHub.
from pathlib import Path
import requests
if not Path("notebook_utils.py").exists():
r = requests.get(
url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
)
open("notebook_utils.py", "w").write(r.text)
from notebook_utils import download_file, VideoPlayer, device_widget, quantization_widget
# Read more about telemetry collection at https://github.com/openvinotoolkit/openvino_notebooks?tab=readme-ov-file#-telemetry
from notebook_utils import collect_telemetry
collect_telemetry("yolov12-object-detection.ipynb")
# Download a test sample
IMAGE_PATH = Path("./data/coco_bike.jpg")
if not IMAGE_PATH.exists():
    download_file(
        url="https://storage.openvinotoolkit.org/repositories/openvino_notebooks/data/data/image/coco_bike.jpg",
        filename=IMAGE_PATH.name,
        directory=IMAGE_PATH.parent,
    )
There are several models available in the original repository, targeted for different tasks. To load a model, you need to specify a path to the model checkpoint. It can be a local path or a name available on the models hub (in that case, the model checkpoint will be downloaded automatically).

You can select one of the represented models using the widget below:
import ipywidgets as widgets
model_id = ["yolo12n", "yolo12s", "yolo12m", "yolo12l", "yolo12x"]
model_name = widgets.Dropdown(options=model_id, value=model_id[0], description="Model")
model_name
Dropdown(description='Model', options=('yolo12n', 'yolo12s', 'yolo12m', 'yolo12l', 'yolo12x'), value='yolo12n'…
When making a prediction, the model accepts a path to the input image and returns a list with a Results class object. For an object detection model, Results contains boxes, as well as utilities for processing the results, for example, the plot()
method for drawing.

Let us consider the examples:
from PIL import Image
from ultralytics import YOLO
DET_MODEL_NAME = model_name.value
det_model = YOLO(f"{DET_MODEL_NAME}.pt")
det_model.to("cpu")
label_map = det_model.model.names
res = det_model(IMAGE_PATH)
Image.fromarray(res[0].plot()[:, :, ::-1])
image 1/1 /home/maleksandr/test_notebooks/yolo-device/openvino_notebooks/notebooks/yolov12-optimization/data/coco_bike.jpg: 480x640 2 bicycles, 2 cars, 1 dog, 94.3ms Speed: 2.6ms preprocess, 94.3ms inference, 1.0ms postprocess per image at shape (1, 3, 480, 640)
Ultralytics provides an API for convenient model export to different formats, including OpenVINO IR. model.export
is responsible for model conversion. We need to specify the format; additionally, we can preserve dynamic shapes in the model.
# object detection model
det_model_path = Path(f"{DET_MODEL_NAME}_openvino_model/{DET_MODEL_NAME}.xml")
if not det_model_path.exists():
    det_model.export(format="openvino", dynamic=True, half=True)
Ultralytics 8.3.142 🚀 Python-3.10.12 torch-2.7.0+cpu CPU (Intel Core(TM) i9-10980XE 3.00GHz) PyTorch: starting from 'yolo12n.pt' with input shape (1, 3, 640, 640) BCHW and output shape(s) (1, 84, 8400) (5.3 MB) OpenVINO: starting export with openvino 2025.1.0-18503-6fec06580ab-releases/2025/1... OpenVINO: export success ✅ 3.4s, saved as 'yolo12n_openvino_model/' (5.7 MB) Export complete (3.6s) Results saved to /home/maleksandr/test_notebooks/yolo-device/openvino_notebooks/notebooks/yolov12-optimization Predict: yolo predict task=detect model=yolo12n_openvino_model imgsz=640 half Validate: yolo val task=detect model=yolo12n_openvino_model imgsz=640 data=None half Visualize: https://netron.app
We can reuse the base model pipeline by specifying one of the Intel devices (intel:gpu, intel:npu, intel:cpu) when running inference with OpenVINO.

Select a device from the dropdown list for running inference using OpenVINO:
device = device_widget()
device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
Now, with preprocessing and postprocessing handled by the Ultralytics pipeline, we are ready to check the model prediction for object detection.
det_model = YOLO(det_model_path.parent, task="detect")
res = det_model(IMAGE_PATH, device=f"intel:{device.value.lower()}")
Image.fromarray(res[0].plot()[:, :, ::-1])
Loading yolo12n_openvino_model for OpenVINO inference... WARNING ⚠️ OpenVINO device 'AUTO' not available. Using 'AUTO' instead. Using OpenVINO LATENCY mode for batch=1 inference... image 1/1 /home/maleksandr/test_notebooks/yolo-device/openvino_notebooks/notebooks/yolov12-optimization/data/coco_bike.jpg: 480x640 2 bicycles, 2 cars, 1 dog, 90.0ms Speed: 3.7ms preprocess, 90.0ms inference, 1.2ms postprocess per image at shape (1, 3, 480, 640)
NNCF provides a suite of advanced algorithms for neural network inference optimization in OpenVINO with minimal accuracy drop. We will use 8-bit quantization in post-training mode (without the fine-tuning pipeline) to optimize the model.
The optimization process contains the following steps:

1. Create a dataset for quantization.
2. Run nncf.quantize to obtain an optimized model.
3. Serialize the OpenVINO IR model, using the openvino.save_model function.

A condensed sketch of this flow is shown below. Please select below whether you would like to run quantization to improve model inference speed.
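As a preview, here is a minimal, runnable sketch of these three steps. The toy model and random calibration samples are illustrative stand-ins only; the notebook builds its real calibration dataset from COCO in the cells that follow:

import numpy as np
import torch
import torch.nn as nn
import nncf
import openvino as ov

# A toy model stands in for the YOLO IR so the sketch is fully self-contained.
toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
ov_model = ov.convert_model(toy, example_input=torch.rand(1, 3, 64, 64))

# 1. Wrap an iterable of samples into an nncf.Dataset. Random tensors stand in
#    for real calibration images; real data is required for good accuracy.
calibration_dataset = nncf.Dataset([np.random.rand(1, 3, 64, 64).astype(np.float32) for _ in range(4)])

# 2. Run 8-bit post-training quantization.
quantized_model = nncf.quantize(ov_model, calibration_dataset)

# 3. Serialize the optimized model to OpenVINO IR.
ov.save_model(quantized_model, "toy_int8.xml")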
int8_model_det_path = Path(f"{DET_MODEL_NAME}_openvino_model_int8/{DET_MODEL_NAME}.xml")
quantized_det_model = None
to_quantize = quantization_widget()
to_quantize
Checkbox(value=True, description='Quantization')
Let's load the skip magic extension to skip quantization if to_quantize
is not selected.
if not Path("skip_kernel_extension.py").exists():
r = requests.get(
url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/skip_kernel_extension.py",
)
open("skip_kernel_extension.py", "w").write(r.text)
%load_ext skip_kernel_extension
%%skip not $to_quantize.value
from ultralytics.utils import DEFAULT_CFG
from ultralytics.cfg import get_cfg
from ultralytics.data.converter import coco80_to_coco91_class
from ultralytics.data.utils import check_det_dataset
from zipfile import ZipFile
from ultralytics.data.utils import DATASETS_DIR
DATA_URL = "http://images.cocodataset.org/zips/val2017.zip"
LABELS_URL = "https://github.com/ultralytics/yolov5/releases/download/v1.0/coco2017labels-segments.zip"
CFG_URL = "https://raw.githubusercontent.com/ultralytics/ultralytics/v8.1.0/ultralytics/cfg/datasets/coco.yaml"
OUT_DIR = DATASETS_DIR
DATA_PATH = OUT_DIR / "val2017.zip"
LABELS_PATH = OUT_DIR / "coco2017labels-segments.zip"
CFG_PATH = OUT_DIR / "coco.yaml"
if not int8_model_det_path.exists():
    if not (OUT_DIR / "coco/labels").exists():
        download_file(DATA_URL, DATA_PATH.name, DATA_PATH.parent)
        download_file(LABELS_URL, LABELS_PATH.name, LABELS_PATH.parent)
        download_file(CFG_URL, CFG_PATH.name, CFG_PATH.parent)
        with ZipFile(LABELS_PATH, "r") as zip_ref:
            zip_ref.extractall(OUT_DIR)
        with ZipFile(DATA_PATH, "r") as zip_ref:
            zip_ref.extractall(OUT_DIR / "coco/images")
args = get_cfg(cfg=DEFAULT_CFG)
args.data = str(CFG_PATH)
det_validator = det_model.task_map[det_model.task]["validator"](args=args)
det_validator.data = check_det_dataset(args.data)
det_validator.stride = 32
det_data_loader = det_validator.get_dataloader(OUT_DIR / "coco", 1)
det_validator.is_coco = True
det_validator.class_map = coco80_to_coco91_class()
det_validator.names = label_map
det_validator.metrics.names = det_validator.names
det_validator.nc = 80
val: Fast image access ✅ (ping: 0.0±0.0 ms, read: 2964.0±1417.7 MB/s, size: 91.1 KB)
val: Scanning /home/maleksandr/test_notebooks/olmocr-check/openvino_notebooks/notebooks/yolov11-optimization/datasets/co
We can reuse the validation dataloader in accuracy testing for quantization. For that, it should be wrapped into an nncf.Dataset
object, and a transformation function should be defined to extract only the input tensors.
%%skip not $to_quantize.value
import nncf
if not int8_model_det_path.exists():

    def transform_fn(data_item: dict):
        """
        Quantization transform function. Extracts and preprocesses input data from the dataloader item for quantization.
        Parameters:
            data_item: dict with a data item produced by the DataLoader during iteration
        Returns:
            input_tensor: input data for quantization
        """
        input_tensor = det_validator.preprocess(data_item)["img"].numpy()
        return input_tensor

    quantization_dataset = nncf.Dataset(det_data_loader, transform_fn)
The nncf.quantize
function provides an interface for model quantization. It requires an instance of the OpenVINO Model and quantization dataset.
Optionally, some additional parameters for configuring the quantization process (number of samples for quantization, preset, ignored scope, etc.) can be provided. Ultralytics models contain non-ReLU activation functions, which require asymmetric quantization of activations. To achieve a better result, we will use a mixed
quantization preset. It provides symmetric quantization of weights and asymmetric quantization of activations. For more accurate results, we should keep the operations in the postprocessing subgraph in floating-point precision, using the ignored_scope
parameter.
Note: Model post-training quantization is a time-consuming process. Be patient; it can take several minutes depending on your hardware.
%%skip not $to_quantize.value
import shutil
import openvino as ov
core = ov.Core()
det_ov_model = core.read_model(det_model_path)
if not int8_model_det_path.exists():
    ignored_scope = nncf.IgnoredScope(  # post-processing
        subgraphs=[
            nncf.Subgraph(
                inputs=[
                    "__module.model.21/aten::cat/Concat",
                    "__module.model.21/aten::cat/Concat_1",
                    "__module.model.21/aten::cat/Concat_2",
                ],
                outputs=["__module.model.21/aten::cat/Concat_7"],
            )
        ]
    )

    # Detection model
    quantized_det_model = nncf.quantize(
        det_ov_model,
        quantization_dataset,
        preset=nncf.QuantizationPreset.MIXED,
        ignored_scope=ignored_scope,
    )

    print(f"Quantized detection model will be saved to {int8_model_det_path}")
    ov.save_model(quantized_det_model, str(int8_model_det_path))
    shutil.copy(det_model_path.parent / "metadata.yaml", int8_model_det_path.parent / "metadata.yaml")
WARNING:nncf:NNCF provides best results with torch==2.6.*, while current torch version is 2.7.0+cpu. If you encounter issues, consider switching to torch==2.6.*
/home/maleksandr/test_notebooks/yolo-device/openvino_notebooks/venv/lib/python3.10/site-packages/openvino/runtime/__init__.py:10: DeprecationWarning: The `openvino.runtime` module is deprecated and will be removed in the 2026.0 release. Please replace `openvino.runtime` with `openvino`. warnings.warn(
INFO:nncf:108 ignored nodes were found by subgraphs in the NNCFGraph INFO:nncf:Not adding activation input quantizer for operation: 281 __module.model.21/aten::cat/Concat INFO:nncf:Not adding activation input quantizer for operation: 315 __module.model.21/aten::view/Reshape_3 INFO:nncf:Not adding activation input quantizer for operation: 586 __module.model.21/aten::cat/Concat_1 INFO:nncf:Not adding activation input quantizer for operation: 601 __module.model.21/aten::view/Reshape_4 INFO:nncf:Not adding activation input quantizer for operation: 677 __module.model.21/aten::cat/Concat_2 INFO:nncf:Not adding activation input quantizer for operation: 680 __module.model.21/aten::view/Reshape_5 INFO:nncf:Not adding activation input quantizer for operation: 352 __module.model.21/aten::cat/Concat_4 INFO:nncf:Not adding activation input quantizer for operation: 386 __module.model.21/prim::ListUnpack INFO:nncf:Not adding activation input quantizer for operation: 423 __module.model.21.dfl/aten::view/Reshape INFO:nncf:Not adding activation input quantizer for operation: 424 __module.model.21/aten::sigmoid/Sigmoid INFO:nncf:Not adding activation input quantizer for operation: 461 __module.model.21.dfl/aten::transpose/Transpose INFO:nncf:Not adding activation input quantizer for operation: 492 __module.model.21.dfl/aten::softmax/Softmax INFO:nncf:Not adding activation input quantizer for operation: 519 __module.model.21.dfl.conv/aten::_convolution/Convolution INFO:nncf:Not adding activation input quantizer for operation: 544 __module.model.21.dfl/aten::view/Reshape_1 INFO:nncf:Not adding activation input quantizer for operation: 283 __module.model.21/prim::ListUnpack/VariadicSplit INFO:nncf:Not adding activation input quantizer for operation: 317 __module.model.21/aten::sub/Subtract INFO:nncf:Not adding activation input quantizer for operation: 318 __module.model.21/aten::add/Add_6 INFO:nncf:Not adding activation input quantizer for operation: 354 __module.model.21/aten::add/Add_7 388 __module.model.21/aten::div/Divide INFO:nncf:Not adding activation input quantizer for operation: 355 __module.model.21/aten::sub/Subtract_1 INFO:nncf:Not adding activation input quantizer for operation: 389 __module.model.21/aten::cat/Concat_5 INFO:nncf:Not adding activation input quantizer for operation: 426 __module.model.21/aten::mul/Multiply_3 INFO:nncf:Not adding activation input quantizer for operation: 462 __module.model.21/aten::cat/Concat_7
Quantized detection model will be saved to yolo12n_openvino_model_int8/yolo12n.xml
nncf.quantize
returns the OpenVINO Model class instance, which is suitable for loading on a device to make predictions. The INT8
model input data and output result formats are no different from the floating-point model representation. Therefore, we can reuse the same model pipeline defined above to get the INT8
model result on the image.
%%skip not $to_quantize.value
display(device)
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
%%skip not $to_quantize.value
det_model = YOLO(int8_model_det_path.parent, task="detect")
res = det_model(IMAGE_PATH, device=f"intel:{device.value.lower()}")
display(Image.fromarray(res[0].plot()[:, :, ::-1]))
Loading yolo12n_openvino_model_int8 for OpenVINO inference... WARNING ⚠️ OpenVINO device 'AUTO' not available. Using 'AUTO' instead. Using OpenVINO LATENCY mode for batch=1 inference... image 1/1 /home/maleksandr/test_notebooks/yolo-device/openvino_notebooks/notebooks/yolov12-optimization/data/coco_bike.jpg: 480x640 2 bicycles, 2 cars, 1 dog, 137.3ms Speed: 4.6ms preprocess, 137.3ms inference, 1.3ms postprocess per image at shape (1, 3, 480, 640)
Finally, use the OpenVINO Benchmark Tool to measure the inference performance of the FP32
and INT8
models.
Note: For more accurate performance, it is recommended to run
benchmark_app
in a terminal/command prompt after closing other applications. Runbenchmark_app -m <model_path> -d CPU -shape "<input_shape>"
to benchmark async inference on CPU on specific input data shape for one minute. ChangeCPU
toGPU
to benchmark on GPU. Runbenchmark_app --help
to see an overview of all command-line options.
device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
if int8_model_det_path.exists():
    # Inference FP32 model (OpenVINO IR)
    !benchmark_app -m $det_model_path -d $device.value -api async -shape "[1,3,640,640]" -t 15
[Step 1/11] Parsing and validating input arguments [ INFO ] Parsing input parameters [Step 2/11] Loading OpenVINO Runtime [ INFO ] OpenVINO: [ INFO ] Build ................................. 2025.1.0-18503-6fec06580ab-releases/2025/1 [ INFO ] [ INFO ] Device info: [ INFO ] AUTO [ INFO ] Build ................................. 2025.1.0-18503-6fec06580ab-releases/2025/1 [ INFO ] [ INFO ] [Step 3/11] Setting device configuration [ WARNING ] Performance hint was not explicitly specified in command line. Device(AUTO) performance hint will be set to PerformanceMode.THROUGHPUT. [Step 4/11] Reading model files [ INFO ] Loading model files [ INFO ] Read model took 31.07 ms [ INFO ] Original model I/O parameters: [ INFO ] Model inputs: [ INFO ] x (node: x) : f32 / [...] / [?,3,?,?] [ INFO ] Model outputs: [ INFO ] ***NO_NAME*** (node: __module.model.21/aten::cat/Concat_7) : f32 / [...] / [?,84,21..] [Step 5/11] Resizing model to match image sizes and given batch [ INFO ] Model batch size: 1 [ INFO ] Reshaping model: 'x': [1,3,640,640] [ INFO ] Reshape model took 18.11 ms [Step 6/11] Configuring input of the model [ INFO ] Model inputs: [ INFO ] x (node: x) : u8 / [N,C,H,W] / [1,3,640,640] [ INFO ] Model outputs: [ INFO ] ***NO_NAME*** (node: __module.model.21/aten::cat/Concat_7) : f32 / [...] / [1,84,8400] [Step 7/11] Loading the model to the device [ INFO ] Compile model took 468.31 ms [ INFO ] Start of compilation memory usage: Peak 416364 KB [ INFO ] End of compilation memory usage: Peak 3736212 KB [ INFO ] Compile model ram used 3319848 KB [Step 8/11] Querying optimal runtime parameters [ INFO ] Model: [ INFO ] NETWORK_NAME: Model0 [ INFO ] EXECUTION_DEVICES: ['CPU'] [ INFO ] PERFORMANCE_HINT: PerformanceMode.THROUGHPUT [ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 12 [ INFO ] MULTI_DEVICE_PRIORITIES: CPU [ INFO ] CPU: [ INFO ] CPU_DENORMALS_OPTIMIZATION: False [ INFO ] CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0 [ INFO ] DYNAMIC_QUANTIZATION_GROUP_SIZE: 32 [ INFO ] ENABLE_CPU_PINNING: True [ INFO ] ENABLE_CPU_RESERVATION: False [ INFO ] ENABLE_HYPER_THREADING: True [ INFO ] EXECUTION_DEVICES: ['CPU'] [ INFO ] EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE [ INFO ] INFERENCE_NUM_THREADS: 36 [ INFO ] INFERENCE_PRECISION_HINT: <Type: 'float32'> [ INFO ] KEY_CACHE_GROUP_SIZE: 0 [ INFO ] KEY_CACHE_PRECISION: <Type: 'uint8_t'> [ INFO ] KV_CACHE_PRECISION: <Type: 'uint8_t'> [ INFO ] LOG_LEVEL: Level.NO [ INFO ] MODEL_DISTRIBUTION_POLICY: set() [ INFO ] NETWORK_NAME: Model0 [ INFO ] NUM_STREAMS: 12 [ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 12 [ INFO ] PERFORMANCE_HINT: THROUGHPUT [ INFO ] PERFORMANCE_HINT_NUM_REQUESTS: 0 [ INFO ] PERF_COUNT: NO [ INFO ] SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE [ INFO ] VALUE_CACHE_GROUP_SIZE: 0 [ INFO ] VALUE_CACHE_PRECISION: <Type: 'uint8_t'> [ INFO ] MODEL_PRIORITY: Priority.MEDIUM [ INFO ] LOADED_FROM_CACHE: False [ INFO ] PERF_COUNT: False [Step 9/11] Creating infer requests and preparing input tensors [ WARNING ] No input files were given for input 'x'!. This input will be filled with random values! [ INFO ] Fill input 'x' with random values [Step 10/11] Measuring performance (Start inference asynchronously, 12 inference requests, limits: 15000 ms duration) [ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop). 
[ INFO ] First inference took 42.83 ms [Step 11/11] Dumping statistics report [ INFO ] Execution Devices:['CPU'] [ INFO ] Count: 2028 iterations [ INFO ] Duration: 15054.16 ms [ INFO ] Latency: [ INFO ] Median: 87.99 ms [ INFO ] Average: 88.70 ms [ INFO ] Min: 40.06 ms [ INFO ] Max: 118.52 ms [ INFO ] Throughput: 134.71 FPS
if int8_model_det_path.exists():
    # Inference INT8 model (OpenVINO IR)
    !benchmark_app -m $int8_model_det_path -d $device.value -api async -shape "[1,3,640,640]" -t 15
[Step 1/11] Parsing and validating input arguments [ INFO ] Parsing input parameters [Step 2/11] Loading OpenVINO Runtime [ INFO ] OpenVINO: [ INFO ] Build ................................. 2025.1.0-18503-6fec06580ab-releases/2025/1 [ INFO ] [ INFO ] Device info: [ INFO ] AUTO [ INFO ] Build ................................. 2025.1.0-18503-6fec06580ab-releases/2025/1 [ INFO ] [ INFO ] [Step 3/11] Setting device configuration [ WARNING ] Performance hint was not explicitly specified in command line. Device(AUTO) performance hint will be set to PerformanceMode.THROUGHPUT. [Step 4/11] Reading model files [ INFO ] Loading model files [ INFO ] Read model took 48.72 ms [ INFO ] Original model I/O parameters: [ INFO ] Model inputs: [ INFO ] x (node: x) : f32 / [...] / [?,3,?,?] [ INFO ] Model outputs: [ INFO ] ***NO_NAME*** (node: __module.model.21/aten::cat/Concat_7) : f32 / [...] / [?,84,21..] [Step 5/11] Resizing model to match image sizes and given batch [ INFO ] Model batch size: 1 [ INFO ] Reshaping model: 'x': [1,3,640,640] [ INFO ] Reshape model took 28.36 ms [Step 6/11] Configuring input of the model [ INFO ] Model inputs: [ INFO ] x (node: x) : u8 / [N,C,H,W] / [1,3,640,640] [ INFO ] Model outputs: [ INFO ] ***NO_NAME*** (node: __module.model.21/aten::cat/Concat_7) : f32 / [...] / [1,84,8400] [Step 7/11] Loading the model to the device [ INFO ] Compile model took 829.12 ms [ INFO ] Start of compilation memory usage: Peak 419568 KB [ INFO ] End of compilation memory usage: Peak 2907144 KB [ INFO ] Compile model ram used 2487576 KB [Step 8/11] Querying optimal runtime parameters [ INFO ] Model: [ INFO ] NETWORK_NAME: Model0 [ INFO ] EXECUTION_DEVICES: ['CPU'] [ INFO ] PERFORMANCE_HINT: PerformanceMode.THROUGHPUT [ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 12 [ INFO ] MULTI_DEVICE_PRIORITIES: CPU [ INFO ] CPU: [ INFO ] CPU_DENORMALS_OPTIMIZATION: False [ INFO ] CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0 [ INFO ] DYNAMIC_QUANTIZATION_GROUP_SIZE: 32 [ INFO ] ENABLE_CPU_PINNING: True [ INFO ] ENABLE_CPU_RESERVATION: False [ INFO ] ENABLE_HYPER_THREADING: True [ INFO ] EXECUTION_DEVICES: ['CPU'] [ INFO ] EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE [ INFO ] INFERENCE_NUM_THREADS: 36 [ INFO ] INFERENCE_PRECISION_HINT: <Type: 'float32'> [ INFO ] KEY_CACHE_GROUP_SIZE: 0 [ INFO ] KEY_CACHE_PRECISION: <Type: 'uint8_t'> [ INFO ] KV_CACHE_PRECISION: <Type: 'uint8_t'> [ INFO ] LOG_LEVEL: Level.NO [ INFO ] MODEL_DISTRIBUTION_POLICY: set() [ INFO ] NETWORK_NAME: Model0 [ INFO ] NUM_STREAMS: 12 [ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 12 [ INFO ] PERFORMANCE_HINT: THROUGHPUT [ INFO ] PERFORMANCE_HINT_NUM_REQUESTS: 0 [ INFO ] PERF_COUNT: NO [ INFO ] SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE [ INFO ] VALUE_CACHE_GROUP_SIZE: 0 [ INFO ] VALUE_CACHE_PRECISION: <Type: 'uint8_t'> [ INFO ] MODEL_PRIORITY: Priority.MEDIUM [ INFO ] LOADED_FROM_CACHE: False [ INFO ] PERF_COUNT: False [Step 9/11] Creating infer requests and preparing input tensors [ WARNING ] No input files were given for input 'x'!. This input will be filled with random values! [ INFO ] Fill input 'x' with random values [Step 10/11] Measuring performance (Start inference asynchronously, 12 inference requests, limits: 15000 ms duration) [ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop). 
[ INFO ] First inference took 30.79 ms [Step 11/11] Dumping statistics report [ INFO ] Execution Devices:['CPU'] [ INFO ] Count: 2640 iterations [ INFO ] Duration: 15090.58 ms [ INFO ] Latency: [ INFO ] Median: 67.56 ms [ INFO ] Average: 68.29 ms [ INFO ] Min: 34.05 ms [ INFO ] Max: 137.71 ms [ INFO ] Throughput: 174.94 FPS
This section contains suggestions on how to further improve the performance of your application using OpenVINO.

The key advantage of the Async API is that when a device is busy with inference, the application can perform other tasks in parallel (for example, populating inputs or scheduling other requests) rather than waiting for the current inference to complete first. To learn how to perform asynchronous inference using OpenVINO, refer to the Async API tutorial.
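As a minimal sketch of asynchronous inference with OpenVINO's AsyncInferQueue, here is one way it could look. It compiles the FP32 IR exported earlier in this notebook and feeds random frames purely for illustration:

import numpy as np
import openvino as ov

core = ov.Core()
# Compile the IR exported earlier in this notebook; any OpenVINO model works here.
compiled_model = core.compile_model(det_model_path, device.value)

results = {}


def completion_callback(request, frame_id):
    # Called from a background thread when the request finishes.
    results[frame_id] = request.get_output_tensor(0).data.copy()


# A pool of 4 parallel infer requests; the optimal size depends on the device.
infer_queue = ov.AsyncInferQueue(compiled_model, 4)
infer_queue.set_callback(completion_callback)

for frame_id in range(8):
    frame = np.random.rand(1, 3, 640, 640).astype(np.float32)  # stand-in for a real frame
    # start_async returns immediately, so the application can prepare the next
    # frame (or do other work) while this request runs.
    infer_queue.start_async({0: frame}, userdata=frame_id)

infer_queue.wait_all()  # block until every queued request has completed
print(f"Collected {len(results)} results")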
import collections
import time
from IPython import display
import cv2
import numpy as np
# Main processing function to run object detection.
def run_object_detection(
    source=0,
    flip=False,
    use_popup=False,
    skip_first_frames=0,
    model=det_model,
    device=device.value,
    video_width: int = None,  # if not set, the original size is used
):
    player = None
    try:
        # Create a video player to play with the target fps.
        player = VideoPlayer(source=source, flip=flip, fps=30, skip_first_frames=skip_first_frames)
        # Start capturing.
        player.start()
        if use_popup:
            title = "Press ESC to Exit"
            cv2.namedWindow(winname=title, flags=cv2.WINDOW_GUI_NORMAL | cv2.WINDOW_AUTOSIZE)

        processing_times = collections.deque()
        while True:
            # Grab the frame.
            frame = player.next()
            if frame is None:
                print("Source ended")
                break
            if video_width:
                # If the frame is larger than video_width, reduce its size to improve
                # performance; if smaller, increase its size for a better demo experience.
                scale = video_width / max(frame.shape)
                frame = cv2.resize(
                    src=frame,
                    dsize=None,
                    fx=scale,
                    fy=scale,
                    interpolation=cv2.INTER_AREA,
                )
            # Get the results.
            input_image = np.array(frame)

            start_time = time.time()
            detections = model(input_image, verbose=False, device=f"intel:{device.lower()}")
            stop_time = time.time()
            frame = detections[0].plot()

            processing_times.append(stop_time - start_time)
            # Use processing times from the last 200 frames.
            if len(processing_times) > 200:
                processing_times.popleft()

            _, f_width = frame.shape[:2]
            # Mean processing time [ms].
            processing_time = np.mean(processing_times) * 1000
            fps = 1000 / processing_time
            cv2.putText(
                img=frame,
                text=f"Inference time: {processing_time:.1f}ms ({fps:.1f} FPS)",
                org=(20, 40),
                fontFace=cv2.FONT_HERSHEY_COMPLEX,
                fontScale=f_width / 1000,
                color=(0, 0, 255),
                thickness=1,
                lineType=cv2.LINE_AA,
            )
            # Use this workaround if there is flickering.
            if use_popup:
                cv2.imshow(winname=title, mat=frame)
                key = cv2.waitKey(1)
                # escape = 27
                if key == 27:
                    break
            else:
                # Encode the numpy array to jpg.
                _, encoded_img = cv2.imencode(ext=".jpg", img=frame, params=[cv2.IMWRITE_JPEG_QUALITY, 100])
                # Create an IPython image.
                i = display.Image(data=encoded_img)
                # Display the image in this notebook.
                display.clear_output(wait=True)
                display.display(i)
    # ctrl-c
    except KeyboardInterrupt:
        print("Interrupted")
    # any other error
    except RuntimeError as e:
        print(e)
    finally:
        if player is not None:
            # Stop capturing.
            player.stop()
        if use_popup:
            cv2.destroyAllWindows()
Use a webcam as the video input. By default, the primary webcam is set with source=0
. If you have multiple webcams, each one will be assigned a consecutive number starting at 0. Set flip=True
when using a front-facing camera. Some web browsers, especially Mozilla Firefox, may cause flickering. If you experience flickering, set use_popup=True
.
NOTE: To use this notebook with a webcam, you need to run the notebook on a computer with a webcam. If you run the notebook on a remote server (for example, in Binder or Google Colab), the webcam will not work. By default, the lower cell will run model inference on a video file. If you want to try live inference on your webcam, set
WEBCAM_INFERENCE = True
Run the object detection:
WEBCAM_INFERENCE = False
if WEBCAM_INFERENCE:
    VIDEO_SOURCE = 0  # Webcam
else:
    VIDEO_SOURCE = "people.mp4"
    if not Path(VIDEO_SOURCE).exists():
        download_file(
            "https://storage.openvinotoolkit.org/repositories/openvino_notebooks/data/data/video/people.mp4",
            "people.mp4",
        )
device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
run_object_detection(
    source=VIDEO_SOURCE,
    flip=True,
    use_popup=False,
    model=det_model,
    device=device.value,
    # video_width=1280,
)
Source ended