This notebook shows how to apply different quantization approaches, such as dynamic, static and aware training quantization, using the Intel Neural Compressor (INC) library, for any task of the GLUE benchmark. This is made possible thanks to 🤗 Optimum, an extension of 🤗 Transformers providing a set of performance optimization tools to train and run models on targeted hardware with maximum efficiency.
If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers, 🤗 Datasets and 🤗 Optimum. Uncomment the following cell and run it.
#! pip install datasets transformers optimum[intel]
Make sure your version of 🤗 Optimum is at least 1.2.3 since the functionality was introduced in that version:
from optimum.intel.version import __version__
print(__version__)
1.2.3
Note that quantization is currently only supported for CPUs, so we will not be utilizing GPUs / CUDA in this notebook.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
The GLUE Benchmark is a group of nine classification tasks on sentences or pairs of sentences: CoLA (Corpus of Linguistic Acceptability), MNLI (Multi-Genre Natural Language Inference), MRPC (Microsoft Research Paraphrase Corpus), QNLI (Question-answering NLI), QQP (Quora Question Pairs), RTE (Recognizing Textual Entailment), SST-2 (Stanford Sentiment Treebank), STS-B (Semantic Textual Similarity Benchmark) and WNLI (Winograd NLI).
We will see how to apply post-training static quantization on a DistilBERT model fine-tuned on the SST-2 task:
GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]
task = "sst2"
model_checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
batch_size = 16
max_train_samples = 100
We can set our quantization_approach to either dynamic, static or aware_training in order to apply dynamic, static or aware training quantization, respectively.
Quantization will be applied to the embeddings and to the linear layers, as well as to their corresponding input activations.
SUPPORTED_QUANTIZATION_APPROACH = ["dynamic", "static", "aware_training"]
quantization_approach = "static"
We will use the 🤗 Datasets library to download the data and get the metric we need to use for evaluation (to compare our quantized model to the baseline). This can be easily done with the functions load_dataset and load_metric.
from datasets import load_dataset, load_metric
Apart from mnli-mm being a special code, we can directly pass our task name to those functions. load_dataset will cache the dataset to avoid downloading it again the next time you run this cell.
actual_task = "mnli" if task == "mnli-mm" else task
dataset = load_dataset("glue", actual_task)
metric = load_metric("glue", actual_task)
2022-06-14 15:28:50 [WARNING] Reusing dataset glue (/home/ella/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
0%| | 0/3 [00:00<?, ?it/s]
Note that load_metric has loaded the proper metric associated with your task: Matthews correlation coefficient for CoLA, Pearson and Spearman correlation for STS-B, accuracy and F1 for MRPC and QQP, and accuracy for the remaining tasks. The metric object therefore only computes the one(s) needed for your task.
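You can call its compute method directly with predictions and references to see what it returns; here is a minimal sketch using randomly generated fake predictions and labels (binary, matching SST-2):
import numpy as np
fake_preds = np.random.randint(0, 2, size=(64,))
fake_labels = np.random.randint(0, 2, size=(64,))
metric.compute(predictions=fake_preds, references=fake_labels)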
Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers Tokenizer, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put them in a format the model expects, as well as generate the other inputs that the model requires.
To do all of this, we instantiate our tokenizer with the AutoTokenizer.from_pretrained method, which will ensure that we get a tokenizer corresponding to the model architecture we want to use and that we download the vocabulary used when pretraining this specific checkpoint. That vocabulary will be cached, so it's not downloaded again the next time we run the cell.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
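You can call this tokenizer directly on one sentence or a pair of sentences; for example, on a pair of made-up sentences:
tokenizer("Hello, this one sentence!", "And this sentence goes with it.")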
To preprocess our dataset, we will thus need the names of the columns containing the sentence(s). The following dictionary keeps track of the correspondence between tasks and column names:
task_to_keys = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
"mnli-mm": ("premise", "hypothesis"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("question", "sentence"),
"qqp": ("question1", "question2"),
"rte": ("sentence1", "sentence2"),
"sst2": ("sentence", None),
"stsb": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
}
We can double-check that it works on our current dataset:
sentence1_key, sentence2_key = task_to_keys[task]
if sentence2_key is None:
    print(f"Sentence: {dataset['train'][0][sentence1_key]}")
else:
    print(f"Sentence 1: {dataset['train'][0][sentence1_key]}")
    print(f"Sentence 2: {dataset['train'][0][sentence2_key]}")
Sentence: hide new secretions from the parental units
We can then write the function that will preprocess our samples. We just feed them to the tokenizer with the argument truncation=True. This will ensure that any input longer than the model can handle is truncated to the maximum length accepted by the model.
max_seq_length = min(128, tokenizer.model_max_length)
padding = "max_length"
def preprocess_function(examples):
    args = (
        (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
    )
    return tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
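As a quick check, we can call it on a few training samples (it simply calls the tokenizer, so for DistilBERT the returned dictionary contains input_ids and attention_mask):
preprocess_function(dataset["train"][:5])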
To apply this function to all the sentences (or pairs of sentences) in our dataset, we just use the map method of the dataset object we created earlier. This will apply the function to all the elements of all the splits in dataset, so our training, validation and testing data will be preprocessed in a single command.
encoded_dataset = dataset.map(preprocess_function, batched=True)
2022-06-14 15:28:55 [WARNING] Loading cached processed dataset at /home/ella/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fa763904b981dcc5.arrow 2022-06-14 15:28:55 [WARNING] Loading cached processed dataset at /home/ella/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fdce79612ea077d3.arrow 2022-06-14 15:28:55 [WARNING] Loading cached processed dataset at /home/ella/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-c3c52d2d662a80eb.arrow
Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to map has changed (and thus that the cached data should not be used). For instance, it will properly detect if you change the task in the first cell and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass load_from_cache_file=False in the call to map to bypass the cached files and force the preprocessing to be applied again.
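For instance, a sketch of how to bypass the cache and force the preprocessing to run again (commented out, since we do not need it here):
# encoded_dataset = dataset.map(preprocess_function, batched=True, load_from_cache_file=False)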
Note that we passed batched=True to encode the texts in batches. This leverages the full benefit of the fast tokenizer we loaded earlier, which will use multi-threading to treat the texts in a batch concurrently.
Now that our data is ready, we can download the pretrained model. Since all our tasks are about sentence classification, we use the AutoModelForSequenceClassification class. As with the tokenizer, the from_pretrained method will download and cache the model for us. The number of labels for our problem is always 2, except for STS-B, which is a regression problem, and MNLI, where we have 3 labels:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
fp_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
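Note that the checkpoint we load is already fine-tuned on SST-2, so we do not pass the number of labels here; if you were starting from a generic checkpoint, you would specify it explicitly. A sketch (the checkpoint name is only an example, commented out so it is not executed):
# num_labels = 3 if task.startswith("mnli") else 1 if task == "stsb" else 2
# model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=num_labels)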
To instantiate a Trainer, we will need to define two more things. The most important is the TrainingArguments, which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model, and all other arguments are optional:
model_name = model_checkpoint.split("/")[-1]
output = f"{model_name}-finetuned-{task}"
args = TrainingArguments(
output,
evaluation_strategy = "epoch",
save_strategy = "epoch",
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
dataloader_drop_last=False if quantization_approach == "dynamic" else True,
)
The last thing to define for our Trainer is how to compute the metrics from the predictions. We need to define a function for this, which will just use the metric we loaded earlier; the only preprocessing we have to do is to take the argmax of our predicted logits (or just squeeze the last axis in the case of STS-B):
import numpy as np
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if task != "stsb":
        predictions = np.argmax(predictions, axis=1)
    else:
        predictions = predictions[:, 0]
    return metric.compute(predictions=predictions, references=labels)
Then we just need to pass all of this along with our datasets to the Trainer:
from transformers import default_data_collator
validation_key = "validation_mismatched" if task == "mnli-mm" else "validation_matched" if task == "mnli" else "validation"
train_dataset = encoded_dataset["train"].select(range(max_train_samples))
trainer = Trainer(
model=fp_model,
args=args,
train_dataset=train_dataset if quantization_approach != "dynamic" else None,
eval_dataset=encoded_dataset[validation_key],
compute_metrics=compute_metrics,
tokenizer=tokenizer,
data_collator=default_data_collator,
)
In the case where we want to apply quantization aware training, we need to provide the Intel Neural Compressor (INC) library with a training function. Note that, as we are using a Trainer, we need to set its model attribute to the quantized model resulting from the INC library.
def train_func(model):
    # Point the Trainer at the quantized model handed over by INC before training
    trainer.model_wrapped = model
    trainer.model = model
    train_result = trainer.train()
    metrics = train_result.metrics
    trainer.save_model()
    trainer.save_metrics("train", metrics)
    trainer.save_state()
In order to evaluate the model's performance before and after quantization, we need to define an evaluation function. The metric chosen to evaluate the drop in performance resulting from quantization will be the Matthews correlation coefficient (MCC) for CoLA, the Pearson correlation coefficient (PCC) for STS-B and accuracy for all other tasks.
metric_name = "eval_" + ("pearson" if task == "stsb" else "matthews_correlation" if task == "cola" else "accuracy")
def eval_func(model):
    trainer.model = model
    metrics = trainer.evaluate()
    return metrics.get(metric_name)
fp_model_result = eval_func(fp_model)
print(f"The full-precision model has an {metric_name} of {round(fp_model_result * 100, 2)}.")
The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message. ***** Running Evaluation ***** Num examples = 872 Batch size = 16
The full-precision model has an eval_accuracy of 91.09.
We instantiate IncQuantizationConfig using a configuration file containing all the information related to quantization and the tuning objective. We set the quantization approach as well as the accuracy target, here tolerating a relative performance drop of 0.02 compared to our baseline, the full-precision model.
from optimum.intel.neural_compressor import IncQuantizationConfig, IncQuantizationMode
config = "echarlaix/bert-base-uncased-sst2-static-quant-test"
q8_config = IncQuantizationConfig.from_pretrained(config, config_file_name="quantization.yml")
accuracy_criterion = 0.02
q8_config.set_config("tuning.accuracy_criterion.relative", accuracy_criterion)
q8_approach = getattr(IncQuantizationMode, quantization_approach.upper()).value
q8_config.set_config("quantization.approach", q8_approach)
For both static and aware training quantization, we use PyTorch FX Graph Mode Quantization.
if quantization_approach != "dynamic":
    q8_config.set_config("model.framework", "pytorch_fx")
To instantiate an IncQuantizer, we need a configuration containing all the information relative to quantization and tuning (which can be either a path to a YAML file or an IncQuantizationConfig object) and an evaluation function, which will be used to evaluate the impact of quantization and thus verify that it fits the tolerance defined by the user.
In the case of static quantization, our IncQuantizer will also need a calibration dataloader in order to perform the calibration step.
In the case of aware training quantization, it will need a training function, which will be used to perform the training while applying quantization.
We can now instantiate our IncOptimizer, which takes the model to quantize together with our quantizer and will take care of the quantization process.
from optimum.intel.neural_compressor import IncQuantizer, IncOptimizer
quantizer = IncQuantizer(
config_path_or_obj=q8_config,
eval_func=eval_func,
train_func=train_func if quantization_approach == "aware_training" else None,
calib_dataloader=trainer.get_train_dataloader() if quantization_approach == "static" else None,
)
optimizer = IncOptimizer(fp_model, quantizer=quantizer)
q_model = optimizer.fit()
The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message. 2022-06-14 15:29:34 [INFO] Start sequential pipeline execution. 2022-06-14 15:29:34 [INFO] The 0th step being executing is QUANTIZATION. 2022-06-14 15:29:34 [INFO] Pass query framework capability elapsed time: 166.45 ms 2022-06-14 15:29:34 [INFO] Get FP32 model baseline. The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message. ***** Running Evaluation ***** Num examples = 872 Batch size = 16 2022-06-14 15:30:10 [INFO] Save tuning history to /home/ella/Projects/huggingface/notebooks/examples/nc_workspace/2022-06-14_15-28-47/./history.snapshot. 2022-06-14 15:30:10 [INFO] FP32 baseline is: [Accuracy: 0.9109, Duration (seconds): 36.2036] /home/ella/miniconda3/envs/optimum_inc/lib/python3.8/site-packages/torch/ao/quantization/qconfig.py:88: UserWarning: QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead warnings.warn("QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead") 2022-06-14 15:30:10 [INFO] Fx trace of the entire model failed, We will conduct auto quantization /home/ella/miniconda3/envs/optimum_inc/lib/python3.8/site-packages/torch/ao/quantization/observer.py:177: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch. warnings.warn( 2022-06-14 15:30:10 [WARNING] Please note that calibration sampling size 100 isn't divisible exactly by batch size 16. So the real sampling size is 112. /home/ella/miniconda3/envs/optimum_inc/lib/python3.8/site-packages/torch/nn/quantized/_reference/modules/linear.py:41: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device)) /home/ella/miniconda3/envs/optimum_inc/lib/python3.8/site-packages/torch/nn/quantized/_reference/modules/linear.py:44: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). 
torch.tensor( 2022-06-14 15:30:18 [INFO] |*********Mixed Precision Statistics********| 2022-06-14 15:30:18 [INFO] +---------------------+-------+------+------+ 2022-06-14 15:30:18 [INFO] | Op Type | Total | INT8 | FP32 | 2022-06-14 15:30:18 [INFO] +---------------------+-------+------+------+ 2022-06-14 15:30:18 [INFO] | Embedding | 2 | 2 | 0 | 2022-06-14 15:30:18 [INFO] | LayerNorm | 13 | 0 | 13 | 2022-06-14 15:30:18 [INFO] | quantize_per_tensor | 38 | 38 | 0 | 2022-06-14 15:30:18 [INFO] | Linear | 38 | 38 | 0 | 2022-06-14 15:30:18 [INFO] | dequantize | 38 | 38 | 0 | 2022-06-14 15:30:18 [INFO] | Dropout | 6 | 0 | 6 | 2022-06-14 15:30:18 [INFO] +---------------------+-------+------+------+ 2022-06-14 15:30:18 [INFO] Pass quantize model elapsed time: 8274.07 ms The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message. ***** Running Evaluation ***** Num examples = 872 Batch size = 16 2022-06-14 15:30:43 [INFO] Tune 1 result is: [Accuracy (int8|fp32): 0.9005|0.9109, Duration (seconds) (int8|fp32): 24.7176|36.2036], Best tune result is: [Accuracy: 0.9005, Duration (seconds): 24.7176] 2022-06-14 15:30:43 [INFO] |**********************Tune Result Statistics**********************| 2022-06-14 15:30:43 [INFO] +--------------------+----------+---------------+------------------+ 2022-06-14 15:30:43 [INFO] | Info Type | Baseline | Tune 1 result | Best tune result | 2022-06-14 15:30:43 [INFO] +--------------------+----------+---------------+------------------+ 2022-06-14 15:30:43 [INFO] | Accuracy | 0.9109 | 0.9005 | 0.9005 | 2022-06-14 15:30:43 [INFO] | Duration (seconds) | 36.2036 | 24.7176 | 24.7176 | 2022-06-14 15:30:43 [INFO] +--------------------+----------+---------------+------------------+ 2022-06-14 15:30:43 [INFO] Save tuning history to /home/ella/Projects/huggingface/notebooks/examples/nc_workspace/2022-06-14_15-28-47/./history.snapshot. 2022-06-14 15:30:43 [INFO] Specified timeout or max trials is reached! Found a quantized model which meet accuracy goal. Exit. 2022-06-14 15:30:43 [INFO] Save deploy yaml to /home/ella/Projects/huggingface/notebooks/examples/nc_workspace/2022-06-14_15-28-47/deploy.yaml
q_model_result = eval_func(q_model.model)
print(f"The resulting quantized model has an {metric_name} of {round(q_model_result * 100, 2)}.")
print(f"This results in a drop of {round((fp_model_result - q_model_result) * 100, 2)} in {metric_name} when compared to the full-precision model.")
The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message. ***** Running Evaluation ***** Num examples = 872 Batch size = 16
The resulting quantized model has an eval_accuracy of 90.05. This results in a drop of 1.04 in eval_accuracy when compared to the full-precision model.
import torch
def get_model_size(model):
    # Save the state dict to a temporary file and report its size in MB
    torch.save(model.state_dict(), "tmp.pt")
    model_size = os.path.getsize("tmp.pt") / (1024*1024)
    os.remove("tmp.pt")
    return round(model_size, 2)
fp_model_size = get_model_size(fp_model)
q_model_size = get_model_size(q_model.model)
print(f"The full-precision model size is {round(fp_model_size)} MB while the quantized model one is {round(q_model_size)} MB.")
print(f"The quantized model is {round(fp_model_size / q_model_size, 2)}x smaller than the full-precision one.")
The full-precision model size is 255 MB while the quantized model one is 65 MB. The quantized model is 3.93x smaller than the full-precision one.
We save the resulting quantized model as well as its configuration.
optimizer.save_pretrained(output)
Configuration saved in distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/config.json 2022-06-14 15:31:09 [INFO] Model weights saved to distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2
The previously saved configuration file, containing all the information relative to the model quantization, is used to instantiate an IncOptimizedConfig. We then load the model using IncQuantizedModelForSequenceClassification.
from optimum.intel.neural_compressor.quantization import IncQuantizedModelForSequenceClassification
loaded_model = IncQuantizedModelForSequenceClassification.from_pretrained(output)
loaded_model.eval()
loaded_model_result = eval_func(loaded_model)
loading configuration file distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/config.json Model config DistilBertConfig { "_name_or_path": "distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2", "activation": "gelu", "architectures": [ "DistilBertForSequenceClassification" ], "attention_dropout": 0.1, "dim": 768, "dropout": 0.1, "finetuning_task": "sst-2", "hidden_dim": 3072, "id2label": { "0": "NEGATIVE", "1": "POSITIVE" }, "initializer_range": 0.02, "label2id": { "NEGATIVE": 0, "POSITIVE": 1 }, "max_position_embeddings": 512, "model_type": "distilbert", "n_heads": 12, "n_layers": 6, "output_past": true, "pad_token_id": 0, "problem_type": "single_label_classification", "qa_dropout": 0.1, "seq_classif_dropout": 0.2, "sinusoidal_pos_embds": false, "tie_weights_": true, "torch_dtype": "int8", "transformers_version": "4.19.4", "vocab_size": 30522 } loading configuration file distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/config.json Model config DistilBertConfig { "_name_or_path": "distilbert-base-uncased-finetuned-sst-2-english", "activation": "gelu", "architectures": [ "DistilBertForSequenceClassification" ], "attention_dropout": 0.1, "dim": 768, "dropout": 0.1, "finetuning_task": "sst-2", "hidden_dim": 3072, "id2label": { "0": "NEGATIVE", "1": "POSITIVE" }, "initializer_range": 0.02, "label2id": { "NEGATIVE": 0, "POSITIVE": 1 }, "max_position_embeddings": 512, "model_type": "distilbert", "n_heads": 12, "n_layers": 6, "output_past": true, "pad_token_id": 0, "problem_type": "single_label_classification", "qa_dropout": 0.1, "seq_classif_dropout": 0.2, "sinusoidal_pos_embds": false, "tie_weights_": true, "torch_dtype": "int8", "transformers_version": "4.19.4", "vocab_size": 30522 } loading weights file distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/pytorch_model.bin All model checkpoint weights were used when initializing DistilBertForSequenceClassification. All the weights of DistilBertForSequenceClassification were initialized from the model checkpoint at distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2. If your task is similar to the task the model of the checkpoint was trained on, you can already use DistilBertForSequenceClassification for predictions without further training. /home/ella/miniconda3/envs/optimum_inc/lib/python3.8/site-packages/torch/ao/quantization/observer.py:1124: UserWarning: must run observer before calling calculate_qparams. Returning default scale and zero point warnings.warn( The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message. ***** Running Evaluation ***** Num examples = 872 Batch size = 16
if loaded_model_result == q_model_result:
    print("The quantized model was successfully loaded.")
else:
    print("The quantized model was NOT successfully loaded.")
The quantized model was successfully loaded.
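As a final check, here is a minimal inference sketch on a made-up sentence, assuming the loaded quantized model keeps the standard sequence-classification forward signature and exposes the usual config attribute:
# Hypothetical example sentence, tokenized the same way as during calibration
text = "This movie was absolutely wonderful!"
inputs = tokenizer(text, padding=padding, max_length=max_seq_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    logits = loaded_model(**inputs).logits
predicted_class_id = int(logits.argmax(dim=-1))
print(loaded_model.config.id2label[predicted_class_id])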