#!/usr/bin/env python
# coding: utf-8
#
# # How to export 🤗 Transformers Models to ONNX?
# [ONNX](http://onnx.ai/) is an open format for machine learning models. It allows you to save your neural network's computation graph in a framework-agnostic way, which can be particularly helpful when deploying deep learning models.
#
# Indeed, businesses might have other requirements _(languages, hardware, ...)_ for which the training framework might not be the best suited for inference. In that context, having a representation of the actual computation graph that can be shared across the various business units and applications of an organization is a desirable asset.
#
# Along with the serialization format, ONNX also provides a runtime which allows efficient, hardware-specific execution of the ONNX graph. This is done through the [onnxruntime](https://microsoft.github.io/onnxruntime/) project, which already includes collaborations with many hardware vendors to seamlessly deploy models on various platforms.
#
# Throughout this notebook we'll walk you through the process of converting a PyTorch or TensorFlow transformers model to the [ONNX](http://onnx.ai/) format and leveraging [onnxruntime](https://microsoft.github.io/onnxruntime/) to run inference tasks on models from 🤗 __transformers__.
# ## Exporting a 🤗 transformers model to ONNX
#
# ---
#
# Exporting models _(either PyTorch or TensorFlow)_ is easily achieved through the conversion tool provided as part of the 🤗 __transformers__ repository.
#
# Under the hood, the process is essentially the following _(a manual sketch of these steps with `torch.onnx.export` is shown after the conversion cell below)_:
#
# 1. Allocate the model from transformers (**PyTorch or TensorFlow**)
# 2. Forward dummy inputs through the model so that **ONNX** can record the set of operations executed
# 3. Optionally define dynamic axes on input and output tensors
# 4. Save the graph along with the network parameters
# In[22]:
import sys
get_ipython().system('{sys.executable} -m pip install --upgrade git+https://github.com/huggingface/transformers')
get_ipython().system('{sys.executable} -m pip install --upgrade torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html')
get_ipython().system('{sys.executable} -m pip install --upgrade onnxruntime==1.4.0')
get_ipython().system('{sys.executable} -m pip install -i https://test.pypi.org/simple/ ort-nightly')
get_ipython().system('{sys.executable} -m pip install --upgrade onnxruntime-tools')
# In[23]:
get_ipython().system('rm -rf onnx/')
from pathlib import Path
from transformers.convert_graph_to_onnx import convert
# Handles all the above steps for you
convert(framework="pt", model="bert-base-cased", output=Path("onnx/bert-base-cased.onnx"), opset=11)
# TensorFlow
# convert(framework="tf", model="bert-base-cased", output=Path("onnx/bert-base-cased.onnx"), opset=11)
# ## How to leverage the runtime for inference over an ONNX graph
#
# ---
#
# As mentioned in the introduction, **ONNX** is a serialization format, and many side projects can load the saved graph and run the actual computations from it. Here, we'll focus on the official [onnxruntime](https://microsoft.github.io/onnxruntime/). The runtime is implemented in C++ for performance reasons and provides APIs/bindings for C++, C, C#, Java and Python.
#
# In this notebook, we will use the Python API to highlight how to load a serialized **ONNX** graph and run inference workloads on various backends through **onnxruntime**.
#
# **onnxruntime** is available on PyPI in two flavors:
#
# - onnxruntime: ONNX + MLAS (Microsoft Linear Algebra Subprograms)
# - onnxruntime-gpu: ONNX + MLAS + CUDA
#
# In[24]:
get_ipython().system('pip install transformers onnxruntime-gpu onnx psutil matplotlib')
# ## Preparing for an Inference Session
#
# ---
#
# Inference is done using a specific backend definition which turns on hardware-specific optimizations of the graph.
#
# Optimizations are basically of three kinds:
#
# - **Constant Folding**: Convert static variables to constants in the graph
# - **Deadcode Elimination**: Remove nodes never accessed in the graph
# - **Operator Fusing**: Merge multiple instructions into one _(e.g. Linear -> ReLU can be fused into LinearReLU)_
#
# ONNX Runtime automatically applies most of these optimizations; the optimization level is controlled through specific `SessionOptions`.
#
# Note: Some of the latest optimizations that are not yet integrated into ONNX Runtime are available in the [optimization script](https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers), which tunes models for best performance.
# In[25]:
# # An optional step, unless
# # you want to get a model with mixed precision for perf acceleration on newer GPUs
# # or you are working with TensorFlow (tf.keras) models or PyTorch models other than BERT
# !pip install onnxruntime-tools
# from onnxruntime_tools import optimizer
# # Mixed precision conversion for bert-base-cased model converted from PyTorch
# optimized_model = optimizer.optimize_model("bert-base-cased.onnx", model_type='bert', num_heads=12, hidden_size=768)
# optimized_model.convert_model_float32_to_float16()
# optimized_model.save_model_to_file("bert-base-cased.onnx")
# # Optimizations for bert-base-cased model converted from TensorFlow (tf.keras)
# optimized_model = optimizer.optimize_model("bert-base-cased.onnx", model_type='bert_keras', num_heads=12, hidden_size=768)
# optimized_model.save_model_to_file("bert-base-cased.onnx")
# optimize transformer-based models with onnxruntime-tools
from onnxruntime_tools import optimizer
from onnxruntime_tools.transformers.onnx_model_bert import BertOptimizationOptions
# disable embedding layer norm optimization for better model size reduction
opt_options = BertOptimizationOptions('bert')
opt_options.enable_embed_layer_norm = False
opt_model = optimizer.optimize_model(
'onnx/bert-base-cased.onnx',
'bert',
num_heads=12,
hidden_size=768,
optimization_options=opt_options)
opt_model.save_model_to_file('bert.opt.onnx')
# In[26]:
from os import environ
from psutil import cpu_count
# Environment variables for the OpenMP performance optimizations available in onnxruntime
# They need to be set before importing onnxruntime
environ["OMP_NUM_THREADS"] = str(cpu_count(logical=True))
environ["OMP_WAIT_POLICY"] = 'ACTIVE'
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_all_providers
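# As a quick sanity check, we can list the execution providers exposed by the onnxruntime build we
# just imported (the CUDA provider only appears if the onnxruntime-gpu package is installed)
import onnxruntime
print("Available providers:", onnxruntime.get_available_providers())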
# In[27]:
from contextlib import contextmanager
from dataclasses import dataclass
from time import time
from typing import List
from tqdm import trange
def create_model_for_provider(model_path: str, provider: str) -> InferenceSession:
assert provider in get_all_providers(), f"provider {provider} not found, {get_all_providers()}"
    # A few session properties that might have an impact on performance (provided by Microsoft)
options = SessionOptions()
options.intra_op_num_threads = 1
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    # Load the model as a graph and prepare the backend for the requested provider
session = InferenceSession(model_path, options, providers=[provider])
session.disable_fallback()
return session
@contextmanager
def track_infer_time(buffer: List[float]):
start = time()
yield
end = time()
buffer.append(end - start)
@dataclass
class OnnxInferenceResult:
    model_inference_time: List[float]
optimized_model_path: str
# ## Forwarding through our optimized ONNX model running on CPU
#
# ---
#
# When the model is loaded for inference with a specific provider, for instance **CPUExecutionProvider** as above, an optimized graph can be saved. This graph may include various optimizations, and you might be able to see some **higher-level** operations in the graph _(through [Netron](https://github.com/lutzroeder/Netron) for instance)_ such as:
# - **EmbedLayerNormalization**
# - **Attention**
# - **FastGelu**
#
# These operations are an example of the kind of optimizations **onnxruntime** applies, for instance here merging multiple operations into a bigger one _(Operator Fusing)_.
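# To see this yourself, one option _(an illustrative sketch; the dump path below is just an example)_ is to ask **onnxruntime** to serialize the graph it actually runs through `SessionOptions.optimized_model_filepath`, and then list the operator types it contains with the `onnx` package:
from collections import Counter

import onnx

# Dump the graph after onnxruntime has applied its optimizations
dump_options = SessionOptions()
dump_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
dump_options.optimized_model_filepath = "onnx/bert-base-cased.ort-optimized.onnx"
InferenceSession("onnx/bert-base-cased.onnx", dump_options, providers=["CPUExecutionProvider"])

# Count the operator types in the dumped graph; depending on the onnxruntime version and
# optimization level, higher-level fused operators such as Attention should appear here
ort_optimized = onnx.load("onnx/bert-base-cased.ort-optimized.onnx")
print(Counter(node.op_type for node in ort_optimized.graph.node).most_common(10))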
# In[28]:
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
cpu_model = create_model_for_provider("onnx/bert-base-cased.onnx", "CPUExecutionProvider")
# Inputs are provided through numpy arrays
model_inputs = tokenizer("My name is Bert", return_tensors="pt")
inputs_onnx = {k: v.cpu().detach().numpy() for k, v in model_inputs.items()}
# Run the model (None = get all the outputs)
sequence, pooled = cpu_model.run(None, inputs_onnx)
# Print information about outputs
print(f"Sequence output: {sequence.shape}, Pooled output: {pooled.shape}")
# # Benchmarking PyTorch model
#
# _Note: PyTorch model benchmark is run on CPU_
# In[29]:
from transformers import BertModel
PROVIDERS = {
("cpu", "PyTorch CPU"),
# Uncomment this line to enable GPU benchmarking
# ("cuda:0", "PyTorch GPU")
}
results = {}
for device, label in PROVIDERS:
# Move inputs to the correct device
model_inputs_on_device = {
arg_name: tensor.to(device)
for arg_name, tensor in model_inputs.items()
}
    # Allocate the PyTorch model on the target device
model_pt = BertModel.from_pretrained("bert-base-cased").to(device)
for _ in trange(10, desc="Warming up"):
model_pt(**model_inputs_on_device)
# Compute
time_buffer = []
for _ in trange(100, desc=f"Tracking inference time on PyTorch"):
with track_infer_time(time_buffer):
model_pt(**model_inputs_on_device)
# Store the result
results[label] = OnnxInferenceResult(
time_buffer,
None
)
# ## Benchmarking PyTorch & ONNX on CPU
#
# _**Disclaimer: results may vary depending on the hardware used to run the model**_
# In[30]:
PROVIDERS = {
("CPUExecutionProvider", "ONNX CPU"),
# Uncomment this line to enable GPU benchmarking
# ("CUDAExecutionProvider", "ONNX GPU")
}
for provider, label in PROVIDERS:
# Create the model with the specified provider
model = create_model_for_provider("onnx/bert-base-cased.onnx", provider)
# Keep track of the inference time
time_buffer = []
# Warm up the model
model.run(None, inputs_onnx)
# Compute
for _ in trange(100, desc=f"Tracking inference time on {provider}"):
with track_infer_time(time_buffer):
model.run(None, inputs_onnx)
# Store the result
results[label] = OnnxInferenceResult(
time_buffer,
model.get_session_options().optimized_model_filepath
)
# In[31]:
get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
# Compute average inference time + std
time_results = {k: np.mean(v.model_inference_time) * 1e3 for k, v in results.items()}
time_results_std = {k: np.std(v.model_inference_time) * 1e3 for k, v in results.items()}
plt.rcdefaults()
fig, ax = plt.subplots(figsize=(16, 12))
ax.set_ylabel("Avg Inference time (ms)")
ax.set_title("Average inference time (ms) for each provider")
ax.bar(time_results.keys(), time_results.values(), yerr=list(time_results_std.values()))
plt.show()
# # Quantization support from transformers
#
# Quantization enables the use of integer (_instead of floating-point_) arithmetic to run neural network models faster. From a high-level point of view, quantization maps the float32 range of values to int8 with as little loss as possible in model performance.
#
# Hugging Face provides a conversion tool as part of the transformers repository to easily export quantized models to ONNX Runtime. For more information, please refer to the following:
#
# - [Hugging Face documentation on ONNX Runtime quantization support](https://huggingface.co/transformers/master/serialization.html#quantization)
# - [Intel's Explanation of Quantization](https://nervanasystems.github.io/distiller/quantization.html)
#
# With this method, the accuracy of the model remains at the same level as the full-precision model. If you want to see benchmarks on model performance, we recommend reading the [ONNX Runtime notebook](https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/quantization/notebooks/Bert-GLUE_OnnxRuntime_quantization.ipynb) on the subject.
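# As a toy illustration of the float32 -> int8 mapping mentioned above _(just the idea of scale / zero-point affine quantization, not what ONNX Runtime does internally)_:
import numpy as np

x = np.array([-1.5, -0.3, 0.0, 0.7, 2.1], dtype=np.float32)

# Map the observed float32 range onto the 256 representable uint8 values
scale = float(x.max() - x.min()) / 255.0
zero_point = int(np.round(-x.min() / scale))
q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)

# Dequantize to see how much information the mapping loses
x_hat = (q.astype(np.float32) - zero_point) * scale
print("quantized:", q, "- max reconstruction error:", np.abs(x - x_hat).max())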
# # Benchmarking PyTorch quantized model
# In[32]:
import torch
# Quantize
model_pt_quantized = torch.quantization.quantize_dynamic(
model_pt.to("cpu"), {torch.nn.Linear}, dtype=torch.qint8
)
# Warm up
model_pt_quantized(**model_inputs)
# Benchmark PyTorch quantized model
time_buffer = []
for _ in trange(100):
with track_infer_time(time_buffer):
model_pt_quantized(**model_inputs)
results["PyTorch CPU Quantized"] = OnnxInferenceResult(
time_buffer,
None
)
# # Benchmarking ONNX quantized model
# In[33]:
from transformers.convert_graph_to_onnx import quantize
# Transformers allows you to easily convert a float32 model to a quantized int8 model with ONNX Runtime
quantized_model_path = quantize(Path("bert.opt.onnx"))
# Then you just have to load it through ONNX Runtime as you would normally do
quantized_model = create_model_for_provider(quantized_model_path.as_posix(), "CPUExecutionProvider")
# Warm up the model to have a fair comparison
outputs = quantized_model.run(None, inputs_onnx)
# Evaluate performances
time_buffer = []
for _ in trange(100, desc=f"Tracking inference time on CPUExecutionProvider with quantized model"):
with track_infer_time(time_buffer):
outputs = quantized_model.run(None, inputs_onnx)
# Store the result
results["ONNX CPU Quantized"] = OnnxInferenceResult(
time_buffer,
quantized_model_path
)
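# Besides latency, int8 quantization also shrinks the serialized model, since weights are stored on
# one byte instead of four. A quick check on the files produced above:
from os.path import getsize
for model_file in ["onnx/bert-base-cased.onnx", "bert.opt.onnx", quantized_model_path.as_posix()]:
    print(f"{model_file}: {getsize(model_file) / 1e6:.1f} MB")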
# ## Show the inference performance of each provider
# In[34]:
get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
# Compute average inference time + std
time_results = {k: np.mean(v.model_inference_time) * 1e3 for k, v in results.items()}
time_results_std = {k: np.std(v.model_inference_time) * 1e3 for k, v in results.items()}
plt.rcdefaults()
fig, ax = plt.subplots(figsize=(16, 12))
ax.set_ylabel("Avg Inference time (ms)")
ax.set_title("Average inference time (ms) for each provider")
ax.bar(time_results.keys(), time_results.values(), yerr=list(time_results_std.values()))
plt.show()