This notebook shows how to apply different post-training quantization approaches, such as static and dynamic quantization, with ONNX Runtime for any task of the GLUE benchmark. This is made possible thanks to 🤗 Optimum, an extension of 🤗 Transformers providing a set of performance optimization tools to train and run models on targeted hardware with maximum efficiency.
If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers, 🤗 Datasets and 🤗 Optimum. Uncomment the following cell and run it.
#! pip install datasets transformers optimum[onnxruntime]
Make sure your version of 🤗 Optimum is at least 1.1.0 since the functionality was introduced in that version:
from optimum.version import __version__
print(__version__)
1.1.0
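If you prefer to enforce this requirement programmatically, a minimal sketch (assuming the packaging library is available in your environment) could be:
from packaging import version
from optimum.version import __version__
# Fail early if the installed Optimum version is older than 1.1.0
assert version.parse(__version__) >= version.parse("1.1.0"), (
    f"Found optimum {__version__}, please upgrade to at least 1.1.0"
)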
The GLUE Benchmark is a group of nine classification tasks on sentences or pairs of sentences: CoLA, MNLI, MRPC, QNLI, QQP, RTE, SST-2, STS-B and WNLI.
We will see how to apply post-training static quantization on a BERT model fine-tuned on the SST-2 task:
GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]
task = "sst2"
model_checkpoint = "textattack/bert-base-uncased-SST-2"
We will use the 🤗 Datasets library to download the dataset and get the metric we need to use for evaluation. This can easily be done with the functions load_dataset and load_metric.
from datasets import load_dataset, load_metric
load_dataset will cache the dataset to avoid downloading it again the next time you run this cell.
actual_task = "mnli" if task == "mnli-mm" else task
validation_split = "validation_mismatched" if task == "mnli-mm" else "validation_matched" if task == "mnli" else "validation"
eval_dataset = load_dataset("glue", actual_task, split=validation_split)
metric = load_metric("glue", actual_task)
Reusing dataset glue (/home/ella/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
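As a quick, purely illustrative sanity check, you can look at one raw validation example before any preprocessing:
# Inspect the first raw validation example (fields depend on the task)
print(eval_dataset[0])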
Note that load_metric has loaded the metric associated with your task, so the metric object only computes the one(s) needed for your task.
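To see the metric object in action before producing real predictions, here is an illustrative call with randomly generated predictions and references (the values are arbitrary and only check that the computation runs):
import numpy as np
# Random binary predictions/references, only to exercise metric.compute
fake_preds = np.random.randint(0, 2, size=(64,))
fake_labels = np.random.randint(0, 2, size=(64,))
print(metric.compute(predictions=fake_preds, references=fake_labels))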
To preprocess our dataset, we will need the names of the columns containing the sentence(s). The following dictionary keeps track of the correspondence between each task and its column names:
task_to_keys = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
"mnli-mm": ("premise", "hypothesis"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("question", "sentence"),
"qqp": ("question1", "question2"),
"rte": ("sentence1", "sentence2"),
"sst2": ("sentence", None),
"stsb": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
}
We can then write the function that will preprocess our samples. We just feed them to the tokenizer with the argument truncation=True. This will ensure that any input longer than what the selected model can handle is truncated to the maximum length accepted by the model.
sentence1_key, sentence2_key = task_to_keys[task]
def preprocess_function(examples, tokenizer):
args = (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
return tokenizer(*args, padding="max_length", max_length=128, truncation=True)
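As an illustrative check (not required by the pipeline below), you can load the tokenizer from the same checkpoint and run the function on a couple of validation examples to see the produced fields:
from transformers import AutoTokenizer
# Load the checkpoint's tokenizer; the ORTQuantizer created below also exposes it
# as quantizer.tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
encoded = preprocess_function(eval_dataset[:2], tokenizer)
print(list(encoded.keys()))  # e.g. ['input_ids', 'token_type_ids', 'attention_mask']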
We can set our quantization_approach to either dynamic or static in order to apply dynamic or static quantization, respectively.
QUANTIZATION_APPROACH = ["dynamic", "static"]
quantization_approach = "static"
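A minimal guard (illustrative) to fail early if the chosen approach is mistyped:
# Make sure the selected approach is one of the supported values
assert quantization_approach in QUANTIZATION_APPROACH, f"Unknown quantization approach: {quantization_approach}"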
First, let's create the output directory where the resulting quantized model will be saved.
import os
model_name = model_checkpoint.split("/")[-1]
output_dir = f"{model_name}-{quantization_approach}-quantization"
os.makedirs(output_dir, exist_ok=True)
model_path = os.path.join(output_dir, "model.onnx")
q8_model_path = os.path.join(output_dir, "model-quantized.onnx")
We will use the 🤗 Optimum library to instantiate an ORTQuantizer, which will take care of the quantization process. To instantiate an ORTQuantizer, we need to specify the model to quantize as well as the feature, which corresponds to the type of task that we wish to quantize the model for.
from optimum.onnxruntime.quantization import ORTQuantizer
quantizer = ORTQuantizer.from_pretrained(model_checkpoint, feature="sequence-classification")
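Note that the quantizer also exposes the tokenizer of the checkpoint, which we will reuse below for preprocessing; a quick illustrative check:
# The quantizer holds the checkpoint's tokenizer (reused later as quantizer.tokenizer)
print(type(quantizer.tokenizer))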
We also need to create a QuantizationConfig instance, which is the configuration handling the ONNX Runtime quantization related parameters:
- We set per_channel to False in order to apply per-tensor quantization on the weights. As opposed to per-channel quantization, which introduces one set of quantization parameters per channel, per-tensor quantization means that there will be one set of quantization parameters per tensor.
- We set num_calibration_samples, the number of samples to use for the calibration step required by static quantization, to 40.
- operators_to_quantize is used to specify the types of operators to quantize; here we want to quantize all the network's fully connected and embedding layers.
from optimum.onnxruntime.configuration import QuantizationConfig, AutoCalibrationConfig
from onnxruntime.quantization import QuantFormat, QuantizationMode, QuantType
per_channel = False
num_calibration_samples = 40
operators_to_quantize = ["MatMul", "Add", "Gather"]
apply_static_quantization = quantization_approach == "static"
qconfig = QuantizationConfig(
is_static=apply_static_quantization,
format=QuantFormat.QDQ if apply_static_quantization else QuantFormat.QOperator,
mode=QuantizationMode.QLinearOps if apply_static_quantization else QuantizationMode.IntegerOps,
activations_dtype=QuantType.QInt8 if apply_static_quantization else QuantType.QUInt8,
weights_dtype=QuantType.QInt8,
per_channel=per_channel,
operators_to_quantize=operators_to_quantize,
)
When applying static quantization, we need to perform a calibration step during which the activations quantization ranges are computed. This additional step is only required for static quantization, not for dynamic quantization.
Because the quantization of certain nodes often results in a degradation in accuracy, we create an instance of QuantizationPreprocessor to determine the nodes to exclude when applying static quantization.
from functools import partial
from optimum.onnxruntime.preprocessors import QuantizationPreprocessor
from optimum.onnxruntime.preprocessors.passes import (
ExcludeGeLUNodes,
ExcludeLayerNormNodes,
ExcludeNodeAfter,
ExcludeNodeFollowedBy,
)
ranges = None
quantization_preprocessor = None
if apply_static_quantization:
# Create the calibration dataset used for the calibration step
calibration_dataset = quantizer.get_calibration_dataset(
"glue",
dataset_config_name=actual_task,
preprocess_function=partial(preprocess_function, tokenizer=quantizer.tokenizer),
num_samples=num_calibration_samples,
dataset_split="train",
)
calibration_config = AutoCalibrationConfig.minmax(calibration_dataset)
# Perform the calibration step: computes the activations quantization ranges
ranges = quantizer.fit(
dataset=calibration_dataset,
calibration_config=calibration_config,
onnx_model_path=model_path,
)
quantization_preprocessor = QuantizationPreprocessor(model_path)
# Exclude the nodes constituting LayerNorm
quantization_preprocessor.register_pass(ExcludeLayerNormNodes())
# Exclude the nodes constituting GELU
quantization_preprocessor.register_pass(ExcludeGeLUNodes())
# Exclude the residual connection Add nodes
quantization_preprocessor.register_pass(ExcludeNodeAfter("Add", "Add"))
# Exclude the Add nodes following the Gather operator
quantization_preprocessor.register_pass(ExcludeNodeAfter("Gather", "Add"))
# Exclude the Add nodes followed by the Softmax operator
quantization_preprocessor.register_pass(ExcludeNodeFollowedBy("Add", "Softmax"))
Reusing dataset glue (/home/ella/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Loading cached shuffled indices for dataset at /home/ella/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9e7c65aff5c29f4a.arrow
Loading cached processed dataset at /home/ella/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-6eddf1eb07f9b2a9.arrow
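If you are curious about the calibration output, the following illustrative snippet assumes ranges behaves like a mapping from tensor names to their computed (min, max) activation values; it stays None when dynamic quantization is selected:
# Illustrative inspection of the computed activation ranges (assumed mapping structure)
if ranges is not None:
    tensor_name = next(iter(ranges))
    print(tensor_name, ranges[tensor_name])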
Finally, we export the quantized model.
quantizer.export(
onnx_model_path=model_path,
onnx_quantized_model_output_path=q8_model_path,
calibration_tensors_range=ranges,
quantization_config=qconfig,
preprocessor=quantization_preprocessor,
)
PosixPath('bert-base-uncased-SST-2-static-quantization/model-quantized.onnx')
To evaluate our resulting quantized model, we need to define how to compute the metrics from the predictions. We define a function for this, which will just use the metric we loaded earlier; the only preprocessing we have to do is to take the argmax of our predicted logits (or just squeeze the last axis in the case of STS-B).
The metric chosen to evaluate the quantized model's performance will be the Matthews correlation coefficient (MCC) for CoLA, the Pearson correlation coefficient (PCC) for STS-B and accuracy for all other tasks.
import numpy as np
def compute_metrics(eval_pred):
predictions, labels = eval_pred
if task != "stsb":
predictions = np.argmax(predictions, axis=1)
else:
predictions = predictions[:, 0]
return metric.compute(predictions=predictions, references=labels)
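As an illustrative test of compute_metrics, you can feed it made-up logits and labels (arbitrary values, not produced by any model):
# Arbitrary logits and labels, only to check that compute_metrics runs end to end
dummy_logits = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
dummy_labels = np.array([1, 0, 0])
print(compute_metrics((dummy_logits, dummy_labels)))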
Then, to apply the preprocessing on all the sentences (or pairs of sentences) of our validation dataset, we just use the map method of the dataset object we created earlier. This will apply the preprocess_function function to all the elements of our validation dataset.
eval_dataset = eval_dataset.map(partial(preprocess_function, tokenizer=quantizer.tokenizer), batched=True)
Loading cached processed dataset at /home/ella/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-20e733ff351964bc.arrow
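You can verify that the mapped dataset now contains the tokenizer outputs alongside the original columns:
# The tokenized fields (e.g. input_ids, attention_mask) should now appear here
print(eval_dataset.column_names)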
Finally, to estimate the drop in performance resulting from quantization, we are going to perform an evaluation step for both models (before and after quantization). In order to perform the latter, we will need to instantiate an ORTModel and thus need:
- The onnx_config associated to the model. This instance of OnnxConfig describes how to export the model to the ONNX format.
- The compute_metrics function that was defined previously.
from optimum.onnxruntime import ORTModel
ort_model = ORTModel(model_path, quantizer._onnx_config, compute_metrics=compute_metrics, label_names=["label"])
model_output = ort_model.evaluation_loop(eval_dataset)
model_output.metrics
{'accuracy': 0.9243119266055045}
q8_ort_model = ORTModel(q8_model_path, quantizer._onnx_config, compute_metrics=compute_metrics, label_names=["label"])
q8_model_output = q8_ort_model.evaluation_loop(eval_dataset)
q8_model_output.metrics
{'accuracy': 0.9071100917431193}
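From the two evaluation runs above, we can directly quantify the accuracy drop caused by quantization:
# Difference in accuracy between the full-precision and the quantized model
accuracy_drop = model_output.metrics["accuracy"] - q8_model_output.metrics["accuracy"]
print(f"Accuracy drop resulting from quantization: {accuracy_drop:.4f}")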
Now let's compute the respective sizes of the full-precision and the quantized model in megabytes (MB):
fp_model_size = os.path.getsize(model_path) / (1024*1024)
q_model_size = os.path.getsize(q8_model_path) / (1024*1024)
The compression ratio (full-precision size over quantized size) resulting from quantization is:
round(fp_model_size / q_model_size, 2)
3.92
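For reference, you can also print both sizes directly (the exact values depend on your run):
print(f"Full-precision model size: {fp_model_size:.2f} MB")
print(f"Quantized model size: {q_model_size:.2f} MB")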