If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers and 🤗 Datasets. Uncomment the following cell and run it.
#! pip install transformers
#! pip install datasets
#! pip install huggingface_hub
If you're opening this notebook locally, make sure your environment has the latest versions of those libraries installed.
To be able to share your model with the community and generate results via the Inference API, there are a few more steps to follow.
First, you need an authentication token from the Hugging Face website (sign up here if you haven't already!). Then run the following cell and enter your token:
from huggingface_hub import notebook_login
notebook_login()
Then you need to install Git-LFS and set up Git if you haven't already. Uncomment the following instructions and adapt them with your name and email:
# !apt install git-lfs
# !git config --global user.email "you@example.com"
# !git config --global user.name "Your Name"
Make sure your version of Transformers is at least 4.16.0 since some of the functionality we use was only introduced in that version.
import transformers
print(transformers.__version__)
4.16.0.dev0
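In this notebook, we'll see how to fine-tune one of the 🤗 Transformers models on a language modeling task. We will cover two types of language modeling tasks:
- Causal language modeling (CLM): the model has to predict the next token in the sentence, so the labels are the same as the inputs, shifted to the left. To make sure the model does not cheat, its attention mask prevents it from accessing the tokens after position i when it predicts token i+1.
- Masked language modeling (MLM): the model has to predict some tokens that are masked in the input. It still has access to the whole sentence, so it can use the tokens both before and after the masked tokens to predict their value.
We will see how to easily load and preprocess the dataset for each one of those tasks, and how to use Keras to fine-tune a model on it.
A script version of this notebook, which you can run directly in a distributed environment with multiple GPUs or on TPU, is available in our examples folder.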
For each of those tasks, we will use the Wikitext 2 dataset as an example. You can load it very easily with the 🤗 Datasets library.
from datasets import load_dataset
datasets = load_dataset("wikitext", "wikitext-2-raw-v1")
Reusing dataset wikitext (/home/matt/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
You can replace the dataset above with any dataset hosted on the Hub or use your own files. Just uncomment the following cell and replace the paths with your own input files:
# datasets = load_dataset("text", data_files={"train": "path_to_train.txt", "validation": "path_to_validation.txt"})
You can also load datasets from a CSV or a JSON file; see the full documentation for more information.
To access an actual element, you need to select a split first, then give an index:
datasets["train"][10]
{'text': ' The game \'s battle system , the BliTZ system , is carried over directly from Valkyira Chronicles . During missions , players select each unit using a top @-@ down perspective of the battlefield map : once a character is selected , the player moves the character around the battlefield in third @-@ person . A character can only act once per @-@ turn , but characters can be granted multiple turns at the expense of other characters \' turns . Each character has a field and distance of movement limited by their Action Gauge . Up to nine characters can be assigned to a single mission . During gameplay , characters will call out if something happens to them , such as their health points ( HP ) getting low or being knocked out by enemy attacks . Each character has specific " Potentials " , skills unique to each character . They are divided into " Personal Potential " , which are innate skills that remain unaltered unless otherwise dictated by the story and can either help or impede a character , and " Battle Potentials " , which are grown throughout the game and always grant boons to a character . To learn Battle Potentials , each character has a unique " Masters Table " , a grid @-@ based skill table that can be used to acquire and link different skills . Characters also have Special Abilities that grant them temporary boosts on the battlefield : Kurt can activate " Direct Command " and move around the battlefield without depleting his Action Point gauge , the character Reila can shift into her " Valkyria Form " and become invincible , while Imca can target multiple enemy units with her heavy weapon . \n'}
To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset.
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(
        dataset
    ), "Can't pick more elements than there are in the dataset."
    # Draw distinct random indices.
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset) - 1)
        while pick in picks:
            pick = random.randint(0, len(dataset) - 1)
        picks.append(pick)
    df = pd.DataFrame(dataset[picks])
    # Show human-readable names instead of integer IDs for ClassLabel columns.
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
show_random_elements(datasets["train"])
  | text
---|---
0 | Calvin Weston – drums \n
1 |
2 |
3 | = = Background = = \n
4 | The entire highway was in Kent County . \n
5 | = = = Regular season = = = \n
6 | On June 7 , 1911 , Madero entered Mexico City . In October 1911 he was elected president , under the banner of the Partido Constitucional Progresista , along with José María Pino Suárez , his new running mate as vice @-@ president . Madero pushed aside Francisco Vázquez Gómez , the vice presidential candidate for the Anti @-@ Reelectionist Party in 1910 , as being too moderate . \n
7 | = = Reception and legacy = = \n
8 |
9 | = = Background = = \n
As we can see, some of the texts are a full paragraph of a Wikipedia article while others are just titles or empty lines.
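The raw lines fall into roughly three kinds. The helper below is a hypothetical heuristic, not part of 🤗 Datasets. It classifies a line by its shape, assuming that headings are wrapped in "=" signs as in the table above. You could use something like it to inspect the dataset before tokenization.

```python
def classify_line(text: str) -> str:
    """Hypothetical heuristic: label a raw Wikitext-2 line as empty, heading or paragraph."""
    stripped = text.strip()
    if not stripped:
        return "empty"
    # Assumption: headings look like " = Title = " or " = = Section = = ".
    if stripped.startswith("=") and stripped.endswith("="):
        return "heading"
    return "paragraph"


samples = ["", " = = Background = = \n", " The entire highway was in Kent County . \n"]
print([classify_line(s) for s in samples])  # ['empty', 'heading', 'paragraph']
```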
For causal language modeling (CLM) we are going to take all the texts in our dataset, tokenize them and concatenate them. Then we will split them into examples of a fixed sequence length. This way the model will receive chunks of contiguous text that may look like:
part of text 1
or
end of text 1 [BOS_TOKEN] beginning of text 2
depending on whether they span multiple original texts or not. The labels will be the same as the inputs, shifted to the left.
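To make the relationship between inputs and labels concrete, here is a pure-Python sketch. The token IDs are made up and are not real distilgpt2 IDs. At every position, the target is simply the token that comes next, and the model may only look at the tokens up to and including the current one.

```python
# Made-up token IDs standing in for one chunk of tokenized text.
token_ids = [464, 2068, 7586, 21831, 18045]

# Causal LM: the target at each position is the next token.
inputs = token_ids[:-1]
targets = token_ids[1:]

for i, target in enumerate(targets):
    context = token_ids[: i + 1]  # everything the model may look at so far
    print(f"context={context} -> target={target}")
```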
We will use the distilgpt2 model for this example. You can pick any of the checkpoints listed here instead:
model_checkpoint = "distilgpt2"
To tokenize all our texts with the same vocabulary that was used when training the model, we have to download a pretrained tokenizer. This is all done by the AutoTokenizer class:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
We can now call the tokenizer on all our texts. This is very simple with the map method from the Datasets library. First we define a function that calls the tokenizer on our texts:
def tokenize_function(examples):
return tokenizer(examples["text"])
Then we apply it to all the splits in our datasets object, using batched=True and 4 processes to speed up the preprocessing. We won't need the text column afterward, so we discard it.
tokenized_datasets = datasets.map(
tokenize_function, batched=True, num_proc=4, remove_columns=["text"]
)
Loading cached processed dataset at /home/matt/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-34f24c4e530ba86d.arrow
Loading cached processed dataset at /home/matt/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-0aa615b6c36c5d9b.arrow
Loading cached processed dataset at /home/matt/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-e504c90124f8ff4a.arrow
If we now look at an element of our datasets, we will see that the text has been replaced by the input_ids (and attention_mask) the model will need:
tokenized_datasets["train"][1]
{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1], 'input_ids': [796, 569, 18354, 7496, 17740, 6711, 796, 220, 198]}
Now for the harder part: we need to concatenate all our texts together and then split the result into chunks of a fixed size, which we will call block_size. To do this, we will use the map method again, with the option batched=True. When we use batched=True, the function we pass to map() receives multiple inputs at once, so it can return more or fewer examples than it was given. This is what allows us to create our new fixed-length samples.
We can use any block_size up to the maximum length our model was pretrained with, which for models in the gpt2 family is usually something in the range 512-1024. This might be a bit too big to fit in your GPU memory, though, so let's use something a bit smaller: 128.
# block_size = tokenizer.model_max_length
block_size = 128
Then we write the preprocessing function that will group our texts:
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, though you could add padding instead if the model supports it
    # In this, as in all things, we advise you to follow your heart
    total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # The labels are an unshifted copy of the inputs; the model shifts them internally.
    result["labels"] = result["input_ids"].copy()
    return result
Note that we copy the inputs to the labels without shifting them, even though we said earlier that the labels need to be shifted! This is because CausalLM models in the 🤗 Transformers library shift the labels internally when they compute the loss, so we don't need to do it manually.
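To see what group_texts does on a tiny batch, here is a standalone copy of its logic. It has two small differences from the notebook's version. First, block_size is passed as an argument instead of being read from a global. Second, itertools.chain replaces sum(..., []), which produces the same concatenation in linear time. The token IDs are made up.

```python
from itertools import chain


def group_texts_toy(examples, block_size):
    """Standalone copy of the grouping logic with an explicit block_size."""
    # Concatenate each column across all examples in the batch.
    concatenated = {k: list(chain.from_iterable(examples[k])) for k in examples}
    total_length = len(concatenated["input_ids"])
    # Drop the remainder that does not fill a whole block.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


batch = {
    "input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]],
    "attention_mask": [[1, 1, 1], [1, 1], [1, 1, 1, 1, 1]],
}
out = group_texts_toy(batch, block_size=4)
print(out["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

Three input rows became two output rows, and tokens 9 and 10 were dropped as the remainder. This change in row count is exactly why batched=True is required.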
Also note that by default, the map method sends batches of 1,000 examples to the preprocessing function. So here, we drop the remainder to make the concatenated tokenized texts a multiple of block_size every 1,000 examples. You can adjust this behavior by passing a higher batch size (which will also be processed more slowly). You can also speed up the preprocessing by using multiprocessing:
lm_datasets = tokenized_datasets.map(
group_texts,
batched=True,
batch_size=1000,
num_proc=4,
)
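This small arithmetic sketch shows how much dropping the remainder in every batch of 1,000 examples actually costs. The per-batch token counts are made up. Each batch can lose at most block_size - 1 tokens, so the total waste is bounded by the number of batches times 127. That is small compared with hundreds of thousands of tokens.

```python
def tokens_kept(batch_token_counts, block_size):
    """Tokens kept when each batch's remainder is dropped separately."""
    return sum((n // block_size) * block_size for n in batch_token_counts)


# Made-up token totals for three map() batches of 1,000 examples each.
counts = [130_100, 127_990, 129_300]
block_size = 128

per_batch = tokens_kept(counts, block_size)
all_at_once = (sum(counts) // block_size) * block_size
print(per_batch, all_at_once, sum(counts) - per_batch)  # 387200 387328 190
```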
And we can check that our datasets have changed: now the samples contain chunks of block_size contiguous tokens, potentially spanning several of our original texts.
tokenizer.decode(lm_datasets["train"][1]["input_ids"])
' game and follows the " Nameless ", a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven ". \n The game began development in 2010, carrying over a large portion of the work done on Valkyria Chronicles II. While it retained the standard features of the series, it also underwent multiple adjustments, such as making the game more forgiving for series newcomers. Character designer Raita Honjou and composer Hitoshi Sakimoto both returned from previous entries, along with Valkyria Chronicles II director Takeshi Oz'
Now that the data has been preprocessed, we're ready to initialize our model:
from transformers import TFAutoModelForCausalLM
model = TFAutoModelForCausalLM.from_pretrained(model_checkpoint)
All model checkpoint layers were used when initializing TFGPT2LMHeadModel.
All the layers of TFGPT2LMHeadModel were initialized from the model checkpoint at distilgpt2.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFGPT2LMHeadModel for predictions without further training.
Once we've done that, it's time for our optimizer! We can initialize our AdamWeightDecay optimizer directly, or we can use the create_optimizer function to generate an AdamWeightDecay optimizer with a learning rate schedule. In this case, we'll stick with a constant learning rate for simplicity, so let's just use AdamWeightDecay.
from transformers import create_optimizer, AdamWeightDecay
# Use learning_rate: the older lr argument is deprecated in Keras.
optimizer = AdamWeightDecay(learning_rate=2e-5, weight_decay_rate=0.01)
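If you do want a learning rate schedule instead of a constant rate, the shape usually looks like the sketch below: a linear warmup followed by a linear decay to zero. This is a hypothetical pure-Python approximation for illustration, not the create_optimizer implementation. The 100 warmup steps are arbitrary, and 1166 is the number of steps in one epoch of the CLM training run in this notebook.

```python
def lr_at_step(step, init_lr, num_train_steps, num_warmup_steps, end_lr=0.0):
    """Hypothetical linear warmup + linear decay schedule (an approximation)."""
    if num_warmup_steps > 0 and step < num_warmup_steps:
        # Warmup: ramp linearly from 0 up to init_lr.
        return init_lr * step / num_warmup_steps
    # Decay: move linearly from init_lr to end_lr over the remaining steps.
    decay_steps = num_train_steps - num_warmup_steps
    progress = min(step - num_warmup_steps, decay_steps) / decay_steps
    return init_lr + (end_lr - init_lr) * progress


for step in (0, 50, 100, 633, 1166):
    print(step, f"{lr_at_step(step, 2e-5, 1166, 100):.2e}")
```

create_optimizer takes init_lr, num_train_steps and num_warmup_steps and returns the optimizer together with its schedule. Check its API reference for the exact decay it applies.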
Note that most 🤗 Transformers models compute their loss internally, so we don't actually have to specify a loss when compiling! If we leave the loss argument out of compile(), the model will use its internal loss computation as the loss value.
This is an unusual quirk of TensorFlow models in 🤗 Transformers, so it's worth explaining in a little more detail. All 🤗 Transformers models are capable of computing an appropriate loss for their task internally (for example, a CausalLM model will use a cross-entropy loss). To do this, the labels must be provided in the input dict (or equivalently, in the columns argument to to_tf_dataset()), so that they are visible to the model during the forward pass.
This is quite different from the standard Keras way of handling losses. In standard Keras, labels are passed separately and are not visible to the main body of the model. The loss is computed by a function that the user passes to compile(), which uses the model outputs and the labels to compute a loss value.
The approach we take is that if you do not pass a loss to compile(), the model will assume you want the internal loss. If you are doing this, you should make sure that the labels column(s) are included in the input dict or in the columns argument to to_tf_dataset.
If you want to use your own loss, that is of course possible too! If you do this, you should make sure your labels column(s) are passed like normal labels, either as the second argument to model.fit(), or in the label_cols argument to to_tf_dataset.
import tensorflow as tf
model.compile(optimizer=optimizer)
No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! Please ensure your labels are passed as keys in the input dict so that they are accessible to the model during the forward pass. To disable this behaviour, please pass a loss argument, or explicitly pass loss=None if you do not want your model to compute a loss.
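Here is a pure-Python sketch of the quantity a causal LM computes internally when it receives labels: the mean cross-entropy of each next token, with the shift applied inside the loss. The probability tables are made up. Real models work on logits over the full vocabulary and typically skip positions labelled -100. This is an illustration, not the 🤗 Transformers implementation.

```python
import math


def causal_lm_loss(next_token_probs, labels):
    """Mean cross-entropy where position i is scored on labels[i + 1].

    next_token_probs[i] maps candidate token IDs to the probability the model
    assigns at position i; labels are the unshifted input IDs.
    """
    # The last position has no next token, so it contributes nothing.
    losses = [
        -math.log(next_token_probs[i][labels[i + 1]])
        for i in range(len(labels) - 1)
    ]
    return sum(losses) / len(losses)


labels = [5, 7, 9]
probs = [
    {7: 0.5, 9: 0.5},    # position 0 should predict 7
    {9: 0.25, 5: 0.75},  # position 1 should predict 9
    {5: 1.0},            # position 2 has no target and is ignored
]
loss = causal_lm_loss(probs, labels)
print(f"{loss:.4f}")
```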
Next, we convert our datasets to tf.data.Dataset, which Keras understands natively. Dataset objects have a built-in method for this, to_tf_dataset(). Because all our inputs are the same length, no padding is required, so we can use the DefaultDataCollator. Note that our data collators are designed to work with multiple frameworks, so make sure you set the return_tensors="tf" argument to get TensorFlow tensors out: you don't want to accidentally get a load of torch.Tensor objects in the middle of your nice TF code!
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator(return_tensors="tf")
train_set = lm_datasets["train"].to_tf_dataset(
columns=["attention_mask", "input_ids", "labels"],
shuffle=True,
batch_size=16,
collate_fn=data_collator,
)
validation_set = lm_datasets["validation"].to_tf_dataset(
columns=["attention_mask", "input_ids", "labels"],
shuffle=False,
batch_size=16,
collate_fn=data_collator,
)
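Every chunk produced by group_texts has exactly block_size tokens, so collating a batch amounts to stacking the rows. The sketch below is a hypothetical, simplified pure-Python stand-in for that step. It returns nested lists rather than tf.Tensor objects, and it shows why no padding is needed here and what goes wrong when rows have different lengths.

```python
def stack_features(features):
    """Hypothetical stand-in for default collation of equal-length examples:
    a list of per-example dicts becomes one dict of per-column lists."""
    batch = {key: [example[key] for example in features] for key in features[0]}
    # Stacking only gives a rectangular batch if every row has the same length.
    for key, rows in batch.items():
        if len({len(row) for row in rows}) > 1:
            raise ValueError(f"column {key!r} is ragged; it would need padding")
    return batch


examples = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5, 6], "attention_mask": [1, 1, 1], "labels": [4, 5, 6]},
]
batch = stack_features(examples)
print(batch["input_ids"])  # [[1, 2, 3], [4, 5, 6]]
```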
Now we can train our model. We can also add a callback to sync up our model with the Hub: this allows us to resume training from other machines and even test the model's inference quality midway through training! If you don't want to do this, simply remove push_to_hub_callback from the callbacks list in the call to fit().
from transformers.keras_callbacks import PushToHubCallback
from tensorflow.keras.callbacks import TensorBoard
model_name = model_checkpoint.split("/")[-1]
push_to_hub_model_id = f"{model_name}-finetuned-wikitext2"
tensorboard_callback = TensorBoard(log_dir="./clm_model_save/logs")
push_to_hub_callback = PushToHubCallback(
output_dir="./clm_model_save",
tokenizer=tokenizer,
hub_model_id=push_to_hub_model_id,
)
callbacks = [tensorboard_callback, push_to_hub_callback]
model.fit(train_set, validation_data=validation_set, epochs=1, callbacks=callbacks)
/home/matt/PycharmProjects/notebooks/examples/clm_model_save is already a clone of https://huggingface.co/Rocketknight1/distilgpt2-finetuned-wikitext2. Make sure you pull the latest changes with `repo.git_pull()`.
6/1166 [..............................] - ETA: 1:59 - loss: 4.4879
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0517s vs `on_train_batch_end` time: 0.0720s). Check your callbacks.
1166/1166 [==============================] - 140s 114ms/step - loss: 3.8577 - val_loss: 3.6752
<keras.callbacks.History at 0x7f95742a9db0>
Once the training is completed, we can evaluate our model and get its cross-entropy loss on the validation set like this:
eval_loss = model.evaluate(validation_set)
121/121 [==============================] - 4s 33ms/step - loss: 3.6752
The quality of language models is often measured in 'perplexity' rather than cross-entropy. To convert to perplexity, we simply raise e to the power of the cross-entropy loss.
import math
print(f"Perplexity: {math.exp(eval_loss):.2f}")
Perplexity: 39.46
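The Keras loss here is the mean natural-log cross-entropy, so exp(loss) equals the inverse geometric mean of the probabilities the model assigned to the correct tokens. The quick pure-Python check below uses made-up probabilities. It then converts the validation loss reported above.

```python
import math

# Toy probabilities the model might assign to the correct next tokens.
probs = [0.5, 0.25, 0.125]
mean_ce = -sum(math.log(p) for p in probs) / len(probs)
geo_mean = math.prod(probs) ** (1 / len(probs))
print(f"{math.exp(mean_ce):.2f} {1 / geo_mean:.2f}")  # 4.00 4.00

# Validation cross-entropy from the CLM run above.
clm_loss = 3.6752
print(f"Perplexity: {math.exp(clm_loss):.2f}")  # Perplexity: 39.46
```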
If you saved the model with the callback, you can now share it with all your friends, family, or favorite pets: they can all load it with the identifier "your-username/the-name-you-picked", for instance:
from transformers import TFAutoModelForCausalLM
model = TFAutoModelForCausalLM.from_pretrained("sgugger/my-awesome-model")
For masked language modeling (MLM), we are going to use the same preprocessing as before for our dataset, with one additional step: we will randomly mask some tokens (by replacing them with [MASK]), and the labels will be adjusted to only include the masked tokens (we don't have to predict the non-masked tokens).
We will use the distilroberta-base model for this example. You can pick any of the checkpoints listed here instead:
model_checkpoint = "distilroberta-base"
We can apply the same tokenization function as before; we just need to update our tokenizer to use the checkpoint we just picked. Don't panic about the warnings about inputs being too long for the model: remember that we'll be breaking them into shorter chunks right afterwards!
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
tokenized_datasets = datasets.map(
tokenize_function, batched=True, num_proc=4, remove_columns=["text"]
)
Token indices sequence length is longer than the specified maximum sequence length for this model (544 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (560 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (528 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (638 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (522 > 512). Running this sequence through the model will result in indexing errors
And now, we group texts together and chunk them into samples of length block_size. You can skip this step if your dataset is composed of individual sentences.
lm_datasets = tokenized_datasets.map(
group_texts,
batched=True,
batch_size=1000,
num_proc=4,
)
The rest is very similar to what we had, with two exceptions. First we use a model suitable for masked LM:
from transformers import TFAutoModelForMaskedLM
model = TFAutoModelForMaskedLM.from_pretrained(model_checkpoint)
All model checkpoint layers were used when initializing TFRobertaForMaskedLM. All the layers of TFRobertaForMaskedLM were initialized from the model checkpoint at distilroberta-base. If your task is similar to the task the model of the checkpoint was trained on, you can already use TFRobertaForMaskedLM for predictions without further training.
We redefine our optimizer as we did with the CLM model, and we compile the model. We're using the internal loss again, like we did before.
from transformers import create_optimizer, AdamWeightDecay
import tensorflow as tf
# Use learning_rate: the older lr argument is deprecated in Keras.
optimizer = AdamWeightDecay(learning_rate=2e-5, weight_decay_rate=0.01)
model.compile(optimizer=optimizer)
No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! Please ensure your labels are passed as keys in the input dict so that they are accessible to the model during the forward pass. To disable this behaviour, please pass a loss argument, or explicitly pass loss=None if you do not want your model to compute a loss.
Finally, we use a special data_collator. The data_collator is a function that is responsible for taking the samples and batching them into tensors. In the previous example, we had nothing special to do, so we just used the DefaultDataCollator. Here we want to randomly mask tokens. We could do this as a preprocessing step (like the tokenization), but then the tokens would be masked the same way in every epoch. By doing this step inside the data_collator, we ensure that the random masking is done differently each time we go over the data.
To do this masking for us, the library provides a DataCollatorForLanguageModeling. We can adjust the probability of the masking with its mlm_probability argument. As before, make sure you set the return_tensors="tf" argument so the collator returns TensorFlow tensors rather than torch.Tensor objects.
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mlm_probability=0.15, return_tensors="tf"
)
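The following pure-Python sketch approximates the masking strategy DataCollatorForLanguageModeling applies to each batch, assuming the BERT-style split the library uses. About mlm_probability of the positions are selected. Of those, 80% become the mask token, 10% a random token, and 10% stay unchanged. This is a simplification on plain lists: it ignores details such as never masking special tokens, and the mask ID and vocabulary size are made up. The masks are drawn fresh on every call, so each pass over the data sees different masked positions.

```python
import random


def dynamic_mask(input_ids, mask_id, vocab_size, mlm_probability=0.15, rng=random):
    """Sketch of BERT-style masking on a plain list of token IDs."""
    inputs = list(input_ids)
    labels = [-100] * len(inputs)  # -100 = position not scored by the loss
    for i, token in enumerate(input_ids):
        if rng.random() >= mlm_probability:
            continue  # not selected for prediction
        labels[i] = token
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_id                   # 80%: mask token
        elif r < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels


ids = list(range(1000, 1128))  # one made-up block of 128 token IDs
for epoch in range(2):
    rng = random.Random(epoch)  # a different draw on each pass over the data
    _, epoch_labels = dynamic_mask(ids, mask_id=99, vocab_size=50000, rng=rng)
    masked = [i for i, lab in enumerate(epoch_labels) if lab != -100]
    print(f"epoch {epoch}: {len(masked)} masked positions, first few {masked[:5]}")
```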
Now we generate our datasets as before. Remember to pass the data_collator you just made to the collate_fn argument.
train_set = lm_datasets["train"].to_tf_dataset(
columns=["attention_mask", "input_ids", "labels"],
shuffle=True,
batch_size=16,
collate_fn=data_collator,
)
validation_set = lm_datasets["validation"].to_tf_dataset(
columns=["attention_mask", "input_ids", "labels"],
shuffle=False,
batch_size=16,
collate_fn=data_collator,
)
And now we fit our model! As before, we can use a callback to sync with the Hub during training. You can remove it if you don't want to do this!
from transformers.keras_callbacks import PushToHubCallback
model_name = model_checkpoint.split("/")[-1]
push_to_hub_model_id = f"{model_name}-finetuned-wikitext2"
callback = PushToHubCallback(
output_dir="./mlm_model_save",
tokenizer=tokenizer,
hub_model_id=push_to_hub_model_id,
)
model.fit(train_set, validation_data=validation_set, epochs=1, callbacks=[callback])
/home/matt/PycharmProjects/notebooks/examples/mlm_model_save is already a clone of https://huggingface.co/Rocketknight1/distilroberta-base-finetuned-wikitext2. Make sure you pull the latest changes with `repo.git_pull()`.
1202/1202 [==============================] - ETA: 0s - loss: 1.9043
Several commits (2) will be pushed upstream.
1202/1202 [==============================] - 138s 110ms/step - loss: 1.9043 - val_loss: 1.7174
<keras.callbacks.History at 0x7f96e3be36a0>
Like before, we can evaluate our model on the validation set and compute perplexity. The perplexity is much lower than for the CLM objective because for the MLM objective, we only have to make predictions for the masked tokens (which represent 15% of the total here) while having access to the rest of the tokens. It's thus an easier task for the model.
import math
eval_results = model.evaluate(validation_set)
print(f"Perplexity: {math.exp(eval_results):.2f}")
125/125 [==============================] - 4s 32ms/step - loss: 1.7101
Perplexity: 5.53
If you used the callback, you can now share this model with all your friends, family, or favorite pets: they can all load it with the identifier "your-username/the-name-you-picked", for instance:
from transformers import TFAutoModelForMaskedLM
model = TFAutoModelForMaskedLM.from_pretrained("your-username/my-awesome-model")