If you're opening this notebook in Colab, you will probably need to install 🤗 Transformers and 🤗 Datasets. Uncomment the following cell and run it.
#! pip install datasets transformers
If you're opening this notebook locally, make sure your environment has the latest version of those libraries installed.
To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.
First you have to store your authentication token from the Hugging Face website (sign up here if you haven't already!), then execute the following cell and input your token:
from huggingface_hub import notebook_login
notebook_login()
Then you need to install Git-LFS. Uncomment the following instructions:
# !apt install git-lfs
Make sure your version of Transformers is at least 4.11.0 since the functionality was introduced in that version:
import transformers
print(transformers.__version__)
You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs here.
We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.
from transformers.utils import send_example_telemetry
send_example_telemetry("language_modeling_notebook", framework="pytorch")
In this notebook, we'll see how to fine-tune one of the 🤗 Transformers models on a language modeling task. We will cover two types of language modeling tasks:

- Causal language modeling (CLM): the model has to predict the next token in the sentence, so the labels are the same as the inputs, shifted. To make sure the model does not cheat, its attention is masked so that it cannot access the tokens after token i when trying to predict token i+1.
- Masked language modeling (MLM): the model has to predict some tokens that are masked in the input. It still has access to the whole sentence, so it can use the tokens before and after the masked tokens to predict their value.
We will see how to easily load and preprocess the dataset for each one of those tasks, and how to use the Trainer
API to fine-tune a model on it.
A script version of this notebook that you can run directly in a distributed environment or on a TPU is available in our examples folder.
For each of those tasks, we will use the Wikitext 2 dataset as an example. You can load it very easily with the 🤗 Datasets library.
from datasets import load_dataset
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
Reusing dataset wikitext (/home/sgugger/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91)
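The datasets object itself is a DatasetDict, which for wikitext-2 contains a train, a validation and a test split, each with a single text column. Displaying it shows the splits and their number of rows:

datasets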
You can replace the dataset above with any dataset hosted on the hub or use your own files. Just uncomment the following cell and replace the paths with values that will lead to your files:
# datasets = load_dataset("text", data_files={"train": "path_to_train.txt", "validation": "path_to_validation.txt"})
You can also load datasets from a CSV or a JSON file; see the full documentation for more information.
To access an actual element, you need to select a split first, then give an index:
datasets["train"][10]
{'text': ' The game \'s battle system , the BliTZ system , is carried over directly from Valkyira Chronicles . During missions , players select each unit using a top @-@ down perspective of the battlefield map : once a character is selected , the player moves the character around the battlefield in third @-@ person . A character can only act once per @-@ turn , but characters can be granted multiple turns at the expense of other characters \' turns . Each character has a field and distance of movement limited by their Action Gauge . Up to nine characters can be assigned to a single mission . During gameplay , characters will call out if something happens to them , such as their health points ( HP ) getting low or being knocked out by enemy attacks . Each character has specific " Potentials " , skills unique to each character . They are divided into " Personal Potential " , which are innate skills that remain unaltered unless otherwise dictated by the story and can either help or impede a character , and " Battle Potentials " , which are grown throughout the game and always grant boons to a character . To learn Battle Potentials , each character has a unique " Masters Table " , a grid @-@ based skill table that can be used to acquire and link different skills . Characters also have Special Abilities that grant them temporary boosts on the battlefield : Kurt can activate " Direct Command " and move around the battlefield without depleting his Action Point gauge , the character Reila can shift into her " Valkyria Form " and become invincible , while Imca can target multiple enemy units with her heavy weapon . \n'}
To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset.
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
show_random_elements(datasets["train"])
| | text |
|---|---|
0 | DuVall joined Alice in Chains as lead singer during the band 's reunion concerts . Duff McKagan again joined the band for the reunion tour , playing rhythm guitar on selected songs . Before the tour , Kinney mentioned in an interview that he would be interested in writing new material , but not as Alice in Chains . However , AliceinChains.com reported that the band had begun writing new material , with DuVall on lead vocals . \n |
1 | In chess tournaments and matches , the frequency with which each player receives white and black is an important consideration . In matches , the players ' colors in the first game are determined by drawing lots , and alternated thereafter . In round robin tournaments with an odd number of players , each player receives an equal number of whites and blacks ; with an even number of players , each receives one extra white or black . Where one or more players withdraws from the tournament , the tournament director may change the assigned colors in some games so that no player receives two more blacks than whites , or vice versa . The double @-@ round robin tournament is considered to give the most reliable final standings , since each player receives the same number of whites and blacks , and plays both White and Black against each opponent . \n |
2 | = = = Rich Rodriguez era = = = \n |
3 | = = = = First siege of Kaifeng = = = = \n |
4 | Furthermore , it appears that Iguanodon became more quadrupedal as it got older and heavier ; juvenile I. bernissartensis have shorter arms than adults ( 60 % of hindlimb length versus 70 % for adults ) . When walking as a quadruped , the animal 's hands would have been held so that the palms faced each other , as shown by iguanodontian trackways and the anatomy of this genus 's arms and hands . The three toed pes ( foot ) of Iguanodon was relatively long , and when walking , both the hand and the foot would have been used in a digitigrade fashion ( walking on the fingers and toes ) . The maximum speed of Iguanodon has been estimated at 24 km / h ( 15 mph ) , which would have been as a biped ; it would not have been able to gallop as a quadruped . \n |
5 | |
6 | With little guidance on water allocation from the Supreme Court , proponents of the dam feared endless litigation . A Colorado attorney proposed that the seven states which fell within the river 's basin ( California , Nevada , Arizona , Utah , New Mexico , Colorado and Wyoming ) form an interstate compact , with the approval of Congress . Such compacts were authorized by Article I of the United States Constitution but had never been concluded among more than two states . In 1922 , representatives of seven states met with then @-@ Secretary of Commerce Herbert Hoover . Initial talks produced no result , but when the Supreme Court handed down the Wyoming v. Colorado decision undermining the claims of the upstream states , they became anxious to reach an agreement . The resulting Colorado River Compact was signed on November 24 , 1922 . \n |
7 | He goes on to note that , once freed from the possession , the victim will never again be able to eat tofu , azukimeshi , or other foods favored by foxes : \n |
8 | The only legal marriage was between Kody and his first wife Meri , until their legal divorce in September 2014 . ( He legally married fourth wife Robyn in December 2014 in order to legally adopt her three children ) . The other marriages are considered spiritual unions . As of 2015 Kody has been married to Meri for 25 years , Janelle for 22 years , Christine for 21 years , and Robyn for 5 years . Kody and Meri have a daughter named Mariah , their only child . Kody and Janelle have six children : daughters Madison and Savanah and sons Logan , Hunter , Garrison , and Gabriel . Kody and Christine have six children : daughters Aspyn , Mykelti , Gwendlyn , Ysabel , and Truely and son Paedon . Robyn had three children from her first marriage , which was monogamous : Dayton , Aurora , and Breanna . Kody legally adopted them in June 2015 . Kody and Robyn have two children : son Solomon and daughter Ariella . \n |
9 | = = Ranks and insignia = = \n |
As we can see, some of the texts are a full paragraph of a Wikipedia article while others are just titles or empty lines.
For causal language modeling (CLM) we are going to take all the texts in our dataset and concatenate them after they are tokenized. Then we will split the result into examples of a fixed sequence length. This way the model will receive chunks of contiguous text that may look like:
part of text 1
or
end of text 1 [BOS_TOKEN] beginning of text 2
depending on whether they span over several of the original texts in the dataset or not. The labels will be the same as the inputs, shifted to the left.
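To make the idea concrete, here is a tiny illustration (a sketch with made-up token lists, not the actual preprocessing code) of concatenating two tokenized texts and splitting the result into fixed-size chunks:

# Toy illustration: two "tokenized texts" concatenated, then split into chunks of 4.
text_1 = [101, 102, 103, 104, 105]
text_2 = [201, 202, 203]
concatenated = text_1 + text_2
chunks = [concatenated[i : i + 4] for i in range(0, len(concatenated) // 4 * 4, 4)]
print(chunks)  # [[101, 102, 103, 104], [105, 201, 202, 203]] - the second chunk spans both texts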
We will use the distilgpt2
model for this example. You can pick any of the checkpoints listed here instead:
model_checkpoint = "distilgpt2"
To tokenize all our texts with the same vocabulary that was used when training the model, we have to download a pretrained tokenizer. This is all done by the AutoTokenizer
class:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
We can now call the tokenizer on all our texts. This is very simple, using the map
method from the Datasets library. First we define a function that calls the tokenizer on our texts:
def tokenize_function(examples):
    return tokenizer(examples["text"])
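As a quick sanity check (the exact ids depend on the tokenizer), we can call this function on a small slice of the raw dataset; it returns the input_ids and attention_mask for each text:

tokenize_function(datasets["train"][:2])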
Then we apply it to all the splits in our datasets
object, using batched=True
and 4 processes to speed up the preprocessing. We won't need the text
column afterward, so we discard it.
tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])
If we now look at an element of our datasets, we will see the text has been replaced by the input_ids
the model will need:
tokenized_datasets["train"][1]
{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1], 'input_ids': [796, 569, 18354, 7496, 17740, 6711, 796, 220, 198]}
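If you are curious which tokens those ids correspond to, the tokenizer can map them back (the exact tokens depend on the distilgpt2 vocabulary):

tokenizer.convert_ids_to_tokens(tokenized_datasets["train"][1]["input_ids"])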
Now for the harder part: we need to concatenate all our texts together, then split the result into small chunks of a certain block_size
. To do this, we will use the map
method again, with the option batched=True
. This option actually lets us change the number of examples in the datasets by returning a different number of examples than we got. This way, we can create our new samples from a batch of examples.
First, we grab the maximum length our model was pretrained with. This might be too big to fit in your GPU RAM, so here we take a bit less, at just 128.
# block_size = tokenizer.model_max_length
block_size = 128
Then we write the preprocessing function that will group our texts:
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder; we could add padding instead if the model supported it.
    # You can customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
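Before applying this to the whole dataset, we can sanity-check it on a tiny fake batch (made-up token ids, just a sketch): 300 tokens spread over two examples become two chunks of block_size tokens, with the 44-token remainder dropped.

fake_batch = {
    "input_ids": [list(range(200)), list(range(200, 300))],
    "attention_mask": [[1] * 200, [1] * 100],
}
out = group_texts(fake_batch)
# Two examples in, two chunks of 128 tokens out (the remainder is dropped).
print(len(out["input_ids"]), [len(chunk) for chunk in out["input_ids"]])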
First note that we duplicate the inputs for our labels. This is because the models of the 🤗 Transformers library apply the shifting to the right themselves, so we don't need to do it manually.
Also note that by default, the map
method will send a batch of 1,000 examples to be treated by the preprocessing function. So here, we will drop the remainder to make the concatenated tokenized texts a multiple of block_size
every 1,000 examples. You can adjust this behavior by passing a higher batch size (which will also be slower to process). You can also speed up the preprocessing by using multiprocessing:
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=4,
)
And we can check our datasets have changed: now the samples contain chunks of block_size
contiguous tokens, potentially spanning over several of our original texts.
tokenizer.decode(lm_datasets["train"][1]["input_ids"])
' game and follows the " Nameless ", a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven ". \n The game began development in 2010, carrying over a large portion of the work done on Valkyria Chronicles II. While it retained the standard features of the series, it also underwent multiple adjustments, such as making the game more forgiving for series newcomers. Character designer Raita Honjou and composer Hitoshi Sakimoto both returned from previous entries, along with Valkyria Chronicles II director Takeshi Oz'
Now that the data has been cleaned, we're ready to instantiate our Trainer
. First, we load our model:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
And some TrainingArguments
:
from transformers import Trainer, TrainingArguments
model_name = model_checkpoint.split("/")[-1]
training_args = TrainingArguments(
    f"{model_name}-finetuned-wikitext2",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    push_to_hub=True,
)
The last argument sets up everything so we can push the model to the Hub regularly during training. Remove it if you didn't follow the installation steps at the top of the notebook. If you want to save your model locally under a name different from the name of the repository it will be pushed to, or if you want to push your model under an organization instead of your namespace, use the hub_model_id
argument to set the repo name (it needs to be the full name, including your namespace: for instance "sgugger/gpt-finetuned-wikitext2"
or "huggingface/gpt-finetuned-wikitext2"
).
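For example (left commented out like the optional cells above; the repository name is only an illustration), pushing under an organization would look like this:

# training_args = TrainingArguments(
#     f"{model_name}-finetuned-wikitext2",
#     evaluation_strategy="epoch",
#     learning_rate=2e-5,
#     weight_decay=0.01,
#     push_to_hub=True,
#     hub_model_id="huggingface/gpt-finetuned-wikitext2",
# )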
We pass along all of those to the Trainer
class:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
)
And we can train our model:
trainer.train()
Epoch | Training Loss | Validation Loss |
---|---|---|
1 | 3.767790 | 3.664047 |
2 | 3.648033 | 3.645540 |
3 | 3.594949 | 3.641928 |
TrainOutput(global_step=7002, training_loss=3.69408126695944)
Once the training is completed, we can evaluate our model and get its perplexity on the validation set like this:
import math
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
Perplexity: 38.17
You can now upload the result of the training to the Hub; just execute this instruction:
trainer.push_to_hub()
You can now share this model with all your friends, family, favorite pets: they can all load it with the identifier "your-username/the-name-you-picked"
so for instance:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("sgugger/my-awesome-model")
For masked language modeling (MLM) we are going to use the same preprocessing as before for our dataset, with one additional step: we will randomly mask some tokens (by replacing them with [MASK]
) and the labels will be adjusted to only include the masked tokens (we don't have to predict the non-masked tokens).
We will use the distilroberta-base
model for this example. You can pick any of the checkpoints listed here instead:
model_checkpoint = "distilroberta-base"
We can apply the same tokenization function as before; we just need to update our tokenizer to use the checkpoint we just picked:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])
And like before, we group texts together and chunk them into samples of length block_size
. You can skip that step if your dataset is composed of individual sentences.
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=4,
)
The rest is very similar to what we had, with two exceptions. First we use a model suitable for masked LM:
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
Some weights of RobertaForMaskedLM were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['lm_head.decoder.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
We redefine our TrainingArguments
:
model_name = model_checkpoint.split("/")[-1]
training_args = TrainingArguments(
    f"{model_name}-finetuned-wikitext2",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    push_to_hub=True,
)
Like before, the last argument sets up everything so we can push the model to the Hub regularly during training. Remove it if you didn't follow the installation steps at the top of the notebook. If you want to save your model locally under a name different from the name of the repository it will be pushed to, or if you want to push your model under an organization instead of your namespace, use the hub_model_id
argument to set the repo name (it needs to be the full name, including your namespace: for instance "sgugger/bert-finetuned-wikitext2"
or "huggingface/bert-finetuned-wikitext2"
).
Finally, we use a special data_collator
. The data_collator
is a function that is responsible for taking the samples and batching them into tensors. In the previous example, we had nothing special to do, so we just used the default for this argument. Here we want to do the random masking: we could do it as a preprocessing step (like the tokenization), but then the tokens would always be masked the same way at each epoch. By doing this step inside the data_collator
, we ensure this random masking is done in a new way each time we go over the data.
To do this masking for us, the library provides a DataCollatorForLanguageModeling
. We can adjust the probability of the masking:
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
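To see the dynamic masking in action, we can call the collator twice on the same samples and compare the masked positions (a quick check; the positions are random, so your output will differ):

samples = [lm_datasets["train"][i] for i in range(2)]
batch_1 = data_collator(samples)
batch_2 = data_collator(samples)
# Positions with a label different from -100 are the masked ones;
# they change between the two calls.
print((batch_1["labels"] != -100).nonzero())
print((batch_2["labels"] != -100).nonzero())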
Then we just have to pass everything to Trainer
and begin training:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
    data_collator=data_collator,
)
trainer.train()
Epoch | Training Loss | Validation Loss |
---|---|---|
1 | 2.109144 | 1.919501 |
2 | 1.986566 | 1.884295 |
3 | 1.952879 | 1.861102 |
TrainOutput(global_step=7218, training_loss=2.0377309222603213)
Like before, we can evaluate our model on the validation set. The perplexity is much lower than for the CLM objective because for the MLM objective, we only have to make predictions for the masked tokens (which represent 15% of the total here) while having access to the rest of the tokens. It's thus an easier task for the model.
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
Perplexity: 6.37
You can now upload the result of the training to the Hub; just execute this instruction:
trainer.push_to_hub()
You can now share this model with all your friends, family, favorite pets: they can all load it with the identifier "your-username/the-name-you-picked"
so for instance:
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained("sgugger/my-awesome-model")
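You can then use it to fill in masked spans, for instance with the fill-mask pipeline (again a sketch: the model name is the placeholder from above, and <mask> is the mask token of RoBERTa-style tokenizers such as distilroberta-base):

from transformers import pipeline

fill_mask = pipeline("fill-mask", model="sgugger/my-awesome-model")
fill_mask("The Colorado River Compact was signed in <mask>.")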