This notebook is largely inspired by the 🤗 Transformers summarization notebook, which uses PyTorch as the backend for fine-tuning. Here you will use the ORTSeq2SeqTrainer class from the 🤗 Optimum library, which takes ONNX Runtime as the backend to accelerate training.
In this notebook, we will walk through fine-tuning a T5-small model from 🤗 Transformers on a summarization task. We will use the XSum dataset (for extreme summarization), which contains BBC articles accompanied by single-sentence summaries. Both training and inference will be done with ORTSeq2SeqTrainer in Optimum!
Let's speed the training up!
Dependencies
To use ONNX Runtime for training, you need a machine with at least one NVIDIA GPU.
The ONNX Runtime training module needs to be properly installed before launching the notebook! Please follow the instructions in Optimum's documentation to set up your environment.
Check your GPU:
!nvidia-smi
If you're opening this notebook on Colab, you will probably need to install 🤗 Optimum, 🤗 Transformers, 🤗 Datasets and 🤗 Evaluate. Uncomment the following cell and run it.
!pip install optimum transformers datasets evaluate rouge-score nltk "tokenizers>=0.11.0"
import nltk
nltk.download("punkt")
[Optional] If you want to share your model with the community and generate an inference API, there are a few more steps to follow.
First, you have to store your authentication token from the Hugging Face website (sign up here if you haven't already!). Then execute the following cell and paste your token when prompted:
from huggingface_hub import notebook_login
notebook_login()
Then you need to install Git-LFS. Uncomment the following instructions:
!apt install git-lfs
Make sure your version of Transformers is at least 4.15.0:
import transformers
print(transformers.__version__)
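If you prefer a programmatic check to reading the printed version, the sketch below compares the leading numeric part of a version string against a minimum. It is a hypothetical stdlib-only helper (is_at_least is not part of Transformers). It ignores pre-release suffixes such as ".dev0", so for strict comparisons the packaging library's version.parse is the more robust choice.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def is_at_least(installed: str, minimum: str) -> bool:
    """Hypothetical helper: compare the leading "X.Y.Z" parts of two version strings."""
    def numeric_parts(v: str) -> tuple:
        # Keep only the leading digits and dots; suffixes like ".dev0" or "rc1" are ignored,
        # so a pre-release is treated as equal to its final release (a simplification).
        match = re.match(r"\d+(?:\.\d+)*", v)
        if match is None:
            raise ValueError(f"Unrecognized version string: {v!r}")
        parts = [int(p) for p in match.group(0).split(".")]
        # Pad to three components so "4.15" and "4.15.0" compare equal.
        return tuple(parts + [0] * (3 - len(parts)))

    return numeric_parts(installed) >= numeric_parts(minimum)


try:
    installed_version = version("transformers")
    print(installed_version, "OK" if is_at_least(installed_version, "4.15.0") else "too old")
except PackageNotFoundError:
    print("transformers is not installed")
```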
Setup
model_checkpoint = "t5-small"
task = "xsum"
metric_name = "rouge"
batch_size = 8
learning_rate = 2e-5
weight_decay = 0.01
num_train_epochs = 1
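As a rough guide to how these settings translate into training length, the number of optimizer steps per epoch is about the number of training examples divided by the effective batch size, rounded up. The helper below is a hypothetical sketch of that arithmetic and not a Trainer API. It assumes the last partial batch is kept and a single device by default, so it approximates the Trainer's step count rather than reproducing it exactly.

```python
import math


def optimizer_steps_per_epoch(num_examples, per_device_batch_size, num_devices=1, grad_accum_steps=1):
    # Batches per epoch: the last, partial batch still counts (no drop_last).
    batches = math.ceil(num_examples / (per_device_batch_size * num_devices))
    # With gradient accumulation, several batches make up one optimizer step.
    return max(batches // grad_accum_steps, 1)


print(optimizer_steps_per_epoch(100, 8))  # 13
```

Once the data is tokenized, optimizer_steps_per_epoch(len(tokenized_datasets["train"]), batch_size) * num_train_epochs gives an estimate of the total number of training steps.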
We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.
from transformers.utils import send_example_telemetry
send_example_telemetry("summarization_notebook_ort", framework="none")
We will use the 🤗 Datasets library to download the data and the 🤗 Evaluate library to get the metric we need for evaluation (to compare our model to the benchmark). This can be easily done with the load_dataset and evaluate.load functions.
from datasets import load_dataset
import evaluate
raw_datasets = load_dataset(task)
metric = evaluate.load(metric_name)
[Optional] To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset.
import datasets
import random
import pandas as pd
from IPython.display import display, HTML
def show_random_elements(dataset, num_examples=1):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
show_random_elements(raw_datasets["train"])
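The retry loop in show_random_elements keeps drawing indices until it finds unused ones. A more idiomatic way to get distinct random indices is random.sample, which samples without replacement. The hypothetical helper below sketches that alternative. Its result can be passed to dataset[...] exactly like picks.

```python
import random


def pick_random_indices(num_rows, num_examples=1):
    # random.sample draws without replacement, so no index repeats and no retry loop is needed.
    # It raises ValueError when num_examples > num_rows, replacing the explicit assert.
    return random.sample(range(num_rows), num_examples)


print(pick_random_indices(10, 3))
```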
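The metric is an instance of evaluate.EvaluationModule. You can display it to read its description and usage, and call its compute method on some predictions and references: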
metric
fake_preds = ["hello there", "general kenobi"]
fake_labels = ["hello there", "general kenobi"]
metric.compute(predictions=fake_preds, references=fake_labels)
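For an intuition of what the ROUGE scores measure, here is a deliberately simplified ROUGE-1 F-measure: the clipped unigram overlap between prediction and reference, turned into precision, recall and their harmonic mean. This is an illustrative sketch only (rouge1_f is a hypothetical name). The rouge_score package behind the metric uses its own tokenizer and optional stemming (use_stemmer), and it also reports ROUGE-2 and ROUGE-L. When the prediction and reference are identical, as in the fake example above, the score is 1.0.

```python
from collections import Counter


def rouge1_f(prediction: str, reference: str) -> float:
    """Simplified ROUGE-1 F1 on lowercased, whitespace-split tokens."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Clipped overlap: a word counts at most as many times as it appears in both texts.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(rouge1_f("hello there", "hello there"))  # 1.0
print(rouge1_f("the cat sat", "the cat ran"))
```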
Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers Tokenizer, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary), put them in the format the model expects, and generate the other inputs that the model requires.
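To do all of this, we instantiate our tokenizer with the AutoTokenizer.from_pretrained method. This ensures that we get a tokenizer corresponding to the model architecture we want to use and that we download the vocabulary used when pretraining this specific checkpoint. That vocabulary will be cached, so it's not downloaded again the next time we run the cell.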
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
To prepare the targets for our model, we need to tokenize them inside the as_target_tokenizer context manager. This will make sure the tokenizer uses the special tokens corresponding to the targets:
with tokenizer.as_target_tokenizer():
    print(tokenizer(["Hello, this one sentence!", "This is another sentence."]))
If you are using one of the five T5 checkpoints, we have to prefix the inputs with "summarize:" (the model can also translate, and it needs the prefix to know which task to perform).
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""
We can then write the function that will preprocess our samples. We just feed them to the tokenizer with the argument truncation=True. This ensures that any input longer than the max_length we pass (max_input_length for the documents, max_target_length for the summaries) is truncated to that length. Padding will be dealt with later on (in a data collator), so that we pad examples to the longest length in the batch rather than in the whole dataset.
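The benefit of padding per batch instead of to a dataset-wide length shows up when you count the pad tokens each strategy adds. The sketch below uses made-up token ID lists and hypothetical helper names (pad_to_longest, count_padding). The real padding is done by the data collator.

```python
def pad_to_longest(batch, pad_id=0):
    # Pad each sequence to the length of the longest sequence in *this* batch.
    longest = max(len(seq) for seq in batch)
    return [seq + [pad_id] * (longest - len(seq)) for seq in batch]


def count_padding(batches, fixed_length=None):
    # Count pad tokens added: per batch (fixed_length=None) or up to one dataset-wide length.
    total = 0
    for batch in batches:
        length = fixed_length if fixed_length is not None else max(len(seq) for seq in batch)
        total += sum(length - len(seq) for seq in batch)
    return total


# Two toy batches of token IDs (made-up values).
toy_batches = [[[1, 2], [3]], [[4, 5, 6, 7, 8], [9, 10, 11, 12]]]
dataset_max = max(len(seq) for batch in toy_batches for seq in batch)

print(pad_to_longest(toy_batches[0]))                        # [[1, 2], [3, 0]]
print(count_padding(toy_batches))                            # 2
print(count_padding(toy_batches, fixed_length=dataset_max))  # 8
```

On these two toy batches, per-batch padding adds 2 pad tokens, while padding everything to the longest sequence in the dataset adds 8.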
max_input_length = 1024
max_target_length = 128
def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
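To make the inputs and outputs of preprocess_function concrete, the sketch below runs the same logic on a toy batch. toy_tokenize is a hypothetical stand-in for the real tokenizer: it assigns one made-up ID per whitespace-separated word and truncates to max_length. A real tokenizer also returns an attention_mask and appends special tokens such as </s>.

```python
vocab = {}


def toy_tokenize(texts, max_length):
    # Hypothetical tokenizer: one ID per whitespace-separated word; unseen words get the next free ID.
    ids = []
    for text in texts:
        tokens = [vocab.setdefault(word, len(vocab)) for word in text.split()]
        ids.append(tokens[:max_length])  # mimics truncation=True with max_length
    return {"input_ids": ids}


def toy_preprocess(examples, prefix="summarize: ", max_input_length=8, max_target_length=4):
    # With batched=True, `examples` is a dict mapping column names to lists of values.
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = toy_tokenize(inputs, max_input_length)
    labels = toy_tokenize(examples["summary"], max_target_length)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


toy_batch = {
    "document": ["one two three four five six seven eight nine ten", "a short article"],
    "summary": ["a b c d e f", "tiny"],
}
toy_out = toy_preprocess(toy_batch)
print([len(ids) for ids in toy_out["input_ids"]])  # [8, 4]
print([len(ids) for ids in toy_out["labels"]])     # [4, 1]
```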
To apply this function to all the document/summary pairs in our dataset, we just use the map method of the raw_datasets object we created earlier. This applies the function to all the elements of all the splits in raw_datasets, so our training, validation and test data will be preprocessed with a single command.
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)
Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to map has changed (and thus when the cached data can't be reused). For instance, it will properly detect if you change the task in the Setup cell and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass load_from_cache_file=False in the call to map to skip the cached files and force the preprocessing to be applied again.
Note that we passed batched=True to encode the texts in batches. This lets us take full advantage of the fast tokenizer we loaded earlier, which uses multi-threading to process the texts in a batch concurrently.
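Conceptually, batched=True means map hands your function slices of up to batch_size rows (1,000 by default) as a dict of column lists, and concatenates what it returns. The hypothetical batched_map below sketches that contract in plain Python. It is not how 🤗 Datasets is implemented, since the library also handles caching, Arrow storage and multiprocessing. Unlike map, which keeps the existing columns unless you pass remove_columns, it returns only the columns produced by the function.

```python
def batched_map(columns, fn, batch_size=1000):
    # Hypothetical sketch of map(..., batched=True): slice the columns into batches,
    # call fn on each dict-of-lists batch, and concatenate the returned columns.
    num_rows = len(next(iter(columns.values())))
    output = {}
    for start in range(0, num_rows, batch_size):
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        for name, values in fn(batch).items():
            output.setdefault(name, []).extend(values)
    return output


calls = []


def word_counts(batch):
    calls.append(len(batch["document"]))  # record the size of each batch fn receives
    return {"num_words": [len(doc.split()) for doc in batch["document"]]}


toy_data = {"document": ["a b", "c", "d e f", "g h i j", "k"]}
print(batched_map(toy_data, word_counts, batch_size=2))  # {'num_words': [2, 1, 3, 4, 1]}
print(calls)                                              # [2, 2, 1]
```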
Now that our data is ready, we can download the pretrained model and fine-tune it. Since our task is a sequence-to-sequence one, we use the AutoModelForSeq2SeqLM class to first load the PyTorch model. As with the tokenizer, the from_pretrained method will download and cache the model for us.
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
from optimum.onnxruntime import ORTSeq2SeqTrainer, ORTSeq2SeqTrainingArguments
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
Note that, unlike when loading a model for a classification task, we don't get a warning about newly initialized weights. This means we used all the weights of the pretrained model, and there is no randomly initialized head in this case.
To instantiate an ORTSeq2SeqTrainer, we will need to define three more things. The most important is ORTSeq2SeqTrainingArguments, a class that contains all the attributes to customize the training. You can also use Seq2SeqTrainingArguments from Transformers, but ORTSeq2SeqTrainingArguments enables more of ONNX Runtime's optimized features. It requires a folder name, which will be used to save the checkpoints of the model; all other arguments are optional:
model_name = model_checkpoint.split("/")[-1]
args = ORTSeq2SeqTrainingArguments(
    f"{model_name}-finetuned-xsum",
    eval_strategy="epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    save_total_limit=3,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    fp16=True,  # mixed precision training
    optim="adamw_ort_fused",
    # push_to_hub=True,
)
Here we set the evaluation to be done at the end of each epoch, tweak the learning rate, use the batch_size defined at the top of the notebook and customize the weight decay. Since the ORTSeq2SeqTrainer will save the model regularly and our dataset is quite large, we tell it to keep at most three checkpoints on disk. Lastly, we use the predict_with_generate option (to properly generate summaries) and activate mixed precision training (to go a bit faster).
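With save_total_limit=3, older checkpoints are deleted as new ones are written, so at most three checkpoint-<step> folders remain in the output directory. The hypothetical helper below sketches that rotation rule. The actual Trainer logic has extra cases; for example, it protects the best checkpoint when load_best_model_at_end is set.

```python
import re


def checkpoints_to_delete(checkpoint_dirs, save_total_limit=3):
    # Hypothetical sketch: keep the newest `save_total_limit` checkpoints, return the rest.
    def step(name):
        match = re.search(r"checkpoint-(\d+)$", name)
        return int(match.group(1)) if match else -1

    # Sort by step number, not by string: "checkpoint-1000" < "checkpoint-500" lexicographically.
    ordered = sorted(checkpoint_dirs, key=step)
    excess = max(len(ordered) - save_total_limit, 0)
    return ordered[:excess]


toy_dirs = ["checkpoint-1500", "checkpoint-500", "checkpoint-2000", "checkpoint-1000"]
print(checkpoints_to_delete(toy_dirs))  # ['checkpoint-500']
```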
The last argument, push_to_hub (commented out above), sets everything up so we can push the model to the Hub regularly during training. Uncomment it only if you followed the installation steps at the top of the notebook. If you want to save your model locally under a name that differs from the name of the repository it will be pushed to, or if you want to push your model under an organization rather than your own namespace, use the hub_model_id argument to set the repo name. It needs to be the full name, including your namespace: for instance "optimum/t5-large-finetuned-xsum".
Then, we need a special kind of data collator, which will pad not only the inputs to the maximum length in the batch, but also the labels:
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=tokenizer.pad_token_id,
    pad_to_multiple_of=8 if args.fp16 else None,
)
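The padding the collator performs can be sketched in plain Python. Inputs are padded with the pad ID and get an attention mask, labels are padded with label_pad_token_id, and lengths can be rounded up to a multiple of pad_to_multiple_of (8 is commonly used with fp16 because it suits GPU Tensor Cores). This is a simplified illustration with a hypothetical function name. The real DataCollatorForSeq2Seq also builds decoder_input_ids from the labels when a model is passed, and it returns tensors. Its default label_pad_token_id is -100, which PyTorch's cross-entropy loss ignores. This notebook passes tokenizer.pad_token_id instead, so padded label positions hold the pad ID rather than being masked out.

```python
def pad_seq2seq_batch(features, pad_token_id=0, label_pad_token_id=-100, pad_to_multiple_of=None):
    # Hypothetical, simplified version of what a seq2seq collator does to a list of examples.
    def padded_length(key):
        longest = max(len(f[key]) for f in features)
        if pad_to_multiple_of:
            # Round up to the next multiple, e.g. 13 -> 16 for pad_to_multiple_of=8.
            longest = -(-longest // pad_to_multiple_of) * pad_to_multiple_of
        return longest

    input_len = padded_length("input_ids")
    label_len = padded_length("labels")
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = input_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # 1 marks real tokens, 0 marks padding the model should not attend to.
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        batch["labels"].append(f["labels"] + [label_pad_token_id] * (label_len - len(f["labels"])))
    return batch


toy_features = [{"input_ids": [5, 6, 7], "labels": [1, 2]}, {"input_ids": [8], "labels": [3]}]
print(pad_seq2seq_batch(toy_features)["labels"])                           # [[1, 2], [3, -100]]
print(pad_seq2seq_batch(toy_features, pad_to_multiple_of=4)["input_ids"])  # [[5, 6, 7, 0], [8, 0, 0, 0]]
```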
The last thing to define for our ORTSeq2SeqTrainer is how to compute the metrics from the predictions. We need to define a function for this, which will just use the metric we loaded earlier, and we have to do a bit of pre-processing to decode the predictions into texts:
import nltk
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Generated predictions may be padded with -100 when gathered across batches; replace it before decoding.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # 🤗 Evaluate's rouge returns plain floats (aggregated F-measures); scale them to percentages.
    result = {key: value * 100 for key, value in result.items()}

    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}
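The two list manipulations in compute_metrics can be checked on toy data: -100 is swapped for the pad ID before decoding, and gen_len averages the number of non-pad tokens per generated sequence. The helpers below are hypothetical plain-Python equivalents of those np.where and np.count_nonzero lines. They use 0 as the pad ID, as T5 does.

```python
def replace_ignore_index(sequences, pad_token_id, ignore_index=-100):
    # -100 is not a real token ID, so swap it for the pad ID before decoding.
    return [[pad_token_id if tok == ignore_index else tok for tok in seq] for seq in sequences]


def mean_generated_length(predictions, pad_token_id):
    # Same idea as the gen_len metric: count non-pad tokens per sequence, then average.
    lengths = [sum(1 for tok in seq if tok != pad_token_id) for seq in predictions]
    return sum(lengths) / len(lengths)


toy_labels = [[37, 1, -100], [94, 12, 1]]
print(replace_ignore_index(toy_labels, pad_token_id=0))  # [[37, 1, 0], [94, 12, 1]]

toy_predictions = [[0, 5, 6, 1], [0, 7, 1, 0]]
print(mean_generated_length(toy_predictions, pad_token_id=0))  # 2.5
```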
Then we just need to pass all of this along with our datasets to the ORTSeq2SeqTrainer:
trainer = ORTSeq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics if args.predict_with_generate else None,
    feature="seq2seq-lm",
)
We can now fine-tune our model by just calling the train method:
trainer.train()
You can now upload the result of the training to the Hub by executing this instruction:
trainer.push_to_hub()
You can also save your fine-tuned model (as a PyTorch or ONNX model) in the output_dir that you set in ORTSeq2SeqTrainingArguments:
trainer.save_model()
Evaluate the performance of the model you just fine-tuned on the validation dataset that you passed to ORTSeq2SeqTrainer by calling the evaluate method.
If you set inference_with_ort=True, inference will be done with the ONNX Runtime backend; otherwise, PyTorch will be used as the backend.
trainer.evaluate(inference_with_ort=True)
Now check your trained ONNX model with Netron, and you might notice that the computation graph has been optimized. Want to accelerate even more?
Check out the graph optimizers and quantizers of 🤗 Optimum!