#!/usr/bin/env python
# coding: utf-8

# If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers and 🤗 Datasets. Uncomment the following cell and execute it:

# In[ ]:

#! pip install datasets transformers[sentencepiece]

# If you're opening this notebook locally, make sure your environment has an install of the latest version of Datasets and a source install of Transformers.

# In[ ]:

from transformers.utils import send_example_telemetry

send_example_telemetry("tokenizer_training_notebook", framework="none")

# # Training your own tokenizer from scratch

# In this notebook, we will see several ways to train your own tokenizer from scratch on a given corpus, so you can then use it to train a language model from scratch.
#
# Why would you need to *train* a tokenizer? Transformer models very often use subword tokenization algorithms, and these need to be trained to identify the parts of words that are frequent in the corpus you are using. We recommend you take a look at the [tokenization chapter](https://huggingface.co/course/chapter2/4?fw=pt) of the Hugging Face course for a general introduction to tokenizers, and at the [tokenizers summary](https://huggingface.co/transformers/tokenizer_summary.html) for a look at the differences between the subword tokenization algorithms.

# ## Getting a corpus

# We will need texts to train our tokenizer. We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download our text data, which can be easily done with the `load_dataset` function:

# In[ ]:

from datasets import load_dataset

# For this example, we will use Wikitext-2 (which contains 4.5MB of texts, so training goes fast for our example), but you can use any dataset you want (and in any language, not just English).

# In[ ]:

dataset = load_dataset("wikitext", name="wikitext-2-raw-v1", split="train")

# We can have a look at the dataset, which has 36,718 texts:

# In[ ]:

dataset

# To access an element, we just have to provide its index:

# In[ ]:

dataset[1]

# We can also access a slice directly, in which case we get a dictionary with the key `"text"` and a list of texts as value:

# In[ ]:

dataset[:5]

# The API to train our tokenizer will require an iterator of batches of texts, for instance a list of lists of texts:

# In[ ]:

batch_size = 1000
all_texts = [dataset[i : i + batch_size]["text"] for i in range(0, len(dataset), batch_size)]

# To avoid loading everything into memory (the Datasets library keeps the elements on disk and only loads them into memory when requested), we define a Python iterator. This is particularly useful if you have a huge dataset:

# In[ ]:

def batch_iterator():
    for i in range(0, len(dataset), batch_size):
        yield dataset[i : i + batch_size]["text"]
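# Just to make sure the iterator yields what we expect, we can peek at the first batch. This quick check is our own optional addition and is not required for training:

# In[ ]:

first_batch = next(batch_iterator())
print(len(first_batch))  # number of texts in the first batch (here, batch_size)
print(first_batch[1])    # one of the raw texts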
# Now let's see how we can use this corpus to train a new tokenizer! There are two APIs to do this: the first one uses an existing tokenizer and trains a new version of it on your corpus in one line of code; the second actually builds your tokenizer block by block, so it lets you customize every step!

# ## Using an existing tokenizer

# If you want to train a tokenizer with the exact same algorithms and parameters as an existing one, you can just use the `train_new_from_iterator` API. For instance, let's train a new version of the GPT-2 tokenizer on Wikitext-2 using the same tokenization algorithm.
#
# First we need to load the tokenizer we want to use as a model:

# In[ ]:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Make sure that the tokenizer you picked is a *fast* version (backed by the 🤗 Tokenizers library), otherwise the rest of the notebook will not run:

# In[ ]:

tokenizer.is_fast

# Then we feed the training corpus (either the list of lists or the iterator we defined earlier) to the `train_new_from_iterator` method. We also have to specify the vocabulary size we want to use:

# In[ ]:

new_tokenizer = tokenizer.train_new_from_iterator(batch_iterator(), vocab_size=25000)

# And that's all there is to it! The training goes very fast thanks to the 🤗 Tokenizers library, backed by Rust.
#
# You now have a new tokenizer ready to preprocess your data and train a language model. You can feed it input texts as usual:

# In[ ]:

new_tokenizer(dataset[:5]["text"])

# You can save it locally with the `save_pretrained` method:

# In[ ]:

new_tokenizer.save_pretrained("my-new-tokenizer")

# Or even push it to the [Hugging Face Hub](https://huggingface.co/models) to use that new tokenizer from anywhere. Just make sure you have your authentication token stored by executing `huggingface-cli login` in a terminal or by executing the following cell:

# In[ ]:

from huggingface_hub import notebook_login

notebook_login()

# You will also need to have `git lfs` installed. You can install it directly from this notebook by uncommenting the following cell:

# In[ ]:

# !apt install git-lfs

# In[ ]:

new_tokenizer.push_to_hub("my-new-shiny-tokenizer")

# The tokenizer can now be reloaded on this machine with:

# In[ ]:

tok = new_tokenizer.from_pretrained("my-new-tokenizer")

# Or from anywhere using the repo ID, which is your namespace followed by a slash and the name you gave in the `push_to_hub` method, so for instance:
#
# ```python
# tok = new_tokenizer.from_pretrained("sgugger/my-new-shiny-tokenizer")
# ```

# Now if you want to create and train a new tokenizer that doesn't look like anything in existence, you will need to build it from scratch using the 🤗 Tokenizers library.

# ## Building your tokenizer from scratch

# To understand how to build your tokenizer from scratch, we have to dive a little deeper into the 🤗 Tokenizers library and the tokenization pipeline. This pipeline takes several steps:
#
# - **Normalization**: Executes all the initial transformations over the input string. For example, when you need to lowercase some text, maybe strip it, or even apply one of the common unicode normalization processes, you will add a Normalizer.
# - **Pre-tokenization**: In charge of splitting the initial input string. This is the component that decides where and how to pre-segment the original string. The simplest example would be to simply split on spaces.
# - **Model**: Handles all the sub-token discovery and generation. This is the part that is trainable and really dependent on your input data.
# - **Post-Processing**: Provides advanced construction features to be compatible with some of the Transformers-based SoTA models. For instance, for BERT it wraps the tokenized sentence with [CLS] and [SEP] tokens.
#
# And to go in the other direction:
#
# - **Decoding**: In charge of mapping a tokenized input back to the original string. The decoder is usually chosen according to the `PreTokenizer` we used previously.
#
# For the training of the model, the 🤗 Tokenizers library provides a `Trainer` class that we will use.
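# To make these components concrete before we build them ourselves, here is an optional cell (our own addition, not part of the original walkthrough) that inspects the pipeline behind the GPT-2 tokenizer we retrained above, through its `backend_tokenizer` attribute:

# In[ ]:

backend = new_tokenizer.backend_tokenizer
print(backend.normalizer)      # typically None for GPT-2, which uses no normalizer
print(backend.pre_tokenizer)   # a ByteLevel pre-tokenizer
print(backend.post_processor)  # a ByteLevel post-processor
print(backend.decoder)         # a ByteLevel decoder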
# All of these building blocks can be combined to create working tokenization pipelines. To give you some examples, we will show three full pipelines here: how to replicate BERT, GPT-2 and ALBERT, which will give you an example of a WordPiece, a BPE and a Unigram tokenizer.

# ### WordPiece model like BERT

# Let's have a look at how we can create a WordPiece tokenizer like the one used for training BERT. The first step is to create a `Tokenizer` with an empty `WordPiece` model:

# In[ ]:

from tokenizers import decoders, models, normalizers, pre_tokenizers, processors, trainers, Tokenizer

tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))

# This `tokenizer` is not ready for training yet. We have to add some preprocessing steps: the normalization (which is optional) and the pre-tokenizer, which will split inputs into the chunks we will call words. The tokens will then be parts of those words (but can't be larger than those words).
#
# In the case of BERT, the normalization is lowercasing. Since BERT is such a popular model, it has its own normalizer:

# In[ ]:

tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)

# If you want to customize it, you can use the existing blocks and compose them in a sequence: here for instance we lowercase, apply NFD normalization and strip the accents:

# In[ ]:

tokenizer.normalizer = normalizers.Sequence(
    [normalizers.NFD(), normalizers.Lowercase(), normalizers.StripAccents()]
)

# There is also a `BertPreTokenizer` we can use directly. It pre-tokenizes using whitespace and punctuation:

# In[ ]:

tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

# Like for the normalizer, we can combine several pre-tokenizers in a `Sequence`. If we want to have a quick look at how it preprocesses the inputs, we can call the `pre_tokenize_str` method:

# In[ ]:

tokenizer.pre_tokenizer.pre_tokenize_str("This is an example!")

# Note that the pre-tokenizer not only splits the text into words but also keeps the offsets, that is, the start and end of each of those words inside the original text. This is what will allow the final tokenizer to match each token to the part of the text it comes from (a feature we use for question answering or token classification tasks).
#
# We can now train our tokenizer (the pipeline is not entirely finished, but we will need a trained tokenizer to build the post-processor); we use a `WordPieceTrainer` for that. The key thing to remember is to pass along the special tokens to the trainer, as they won't be seen in the corpus.

# In[ ]:

special_tokens = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]
trainer = trainers.WordPieceTrainer(vocab_size=25000, special_tokens=special_tokens)

# To actually train the tokenizer, the method looks like what we used before: we can either pass some text files, or an iterator of batches of texts:

# In[ ]:

tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
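# As an optional sanity check (our own addition, not required for the pipeline), we can look at the size of the learned vocabulary and encode a sentence. There are no [CLS] or [SEP] tokens in the output yet, because the post-processor is still missing:

# In[ ]:

print(tokenizer.get_vocab_size())
print(tokenizer.encode("This is an example!").tokens)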
# Now that the tokenizer is trained, we can define the post-processor: we need to add the CLS token at the beginning and the SEP token at the end (for single sentences) or several SEP tokens (for pairs of sentences). We use a [`TemplateProcessing`](https://huggingface.co/docs/tokenizers/python/latest/api/reference.html#tokenizers.processors.TemplateProcessing) to do this, which requires knowing the IDs of the CLS and SEP tokens (which is why we waited until after the training).
#
# So let's first grab the IDs of the two special tokens:

# In[ ]:

cls_token_id = tokenizer.token_to_id("[CLS]")
sep_token_id = tokenizer.token_to_id("[SEP]")
print(cls_token_id, sep_token_id)

# And here is how we can build our post-processor. We have to indicate in the template how to organize the special tokens with one sentence (`$A`) or two sentences (`$A` and `$B`). The `:` followed by a number indicates the token type ID to give to each part.

# In[ ]:

tokenizer.post_processor = processors.TemplateProcessing(
    single="[CLS]:0 $A:0 [SEP]:0",
    pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
    special_tokens=[
        ("[CLS]", cls_token_id),
        ("[SEP]", sep_token_id),
    ],
)

# We can check we get the expected results by encoding a pair of sentences, for instance:

# In[ ]:

encoding = tokenizer.encode("This is one sentence.", "With this one we have a pair.")

# We can look at the tokens to check that the special tokens have been inserted in the right places:

# In[ ]:

encoding.tokens

# And we can check that the token type IDs are correct:

# In[ ]:

encoding.type_ids

# The last piece of this tokenizer is the decoder: we use a `WordPiece` decoder and indicate the special prefix `##`:

# In[ ]:

tokenizer.decoder = decoders.WordPiece(prefix="##")

# Now that our tokenizer is finished, we need to wrap it inside a Transformers object to be able to use it with the Transformers library. More specifically, we have to put it inside the fast tokenizer class corresponding to the model we want to use, here a `BertTokenizerFast`:

# In[ ]:

from transformers import BertTokenizerFast

new_tokenizer = BertTokenizerFast(tokenizer_object=tokenizer)

# And like before, we can use this tokenizer as a normal Transformers tokenizer, and use the `save_pretrained` or `push_to_hub` methods.
#
# If the tokenizer you are building does not match any class in Transformers because it's really special, you can wrap it in `PreTrainedTokenizerFast`.

# ### BPE model like GPT-2

# Let's now have a look at how we can create a BPE tokenizer like the one used for training GPT-2. The first step is to create a `Tokenizer` with an empty `BPE` model:

# In[ ]:

tokenizer = Tokenizer(models.BPE())

# Like before, we have to add the optional normalization (not used in the case of GPT-2) and we need to specify a pre-tokenizer before training. In the case of GPT-2, the pre-tokenizer used is a byte-level pre-tokenizer:

# In[ ]:

tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

# If we want to have a quick look at how it preprocesses the inputs, we can call the `pre_tokenize_str` method:

# In[ ]:

tokenizer.pre_tokenizer.pre_tokenize_str("This is an example!")

# We used the same default as GPT-2 for the prefix space, so you can see that each word gets an initial `'Ġ'` added at the beginning, except the first one.
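# If you are curious what `add_prefix_space` changes, this optional comparison (our own addition) runs the same pre-tokenization with `add_prefix_space=True`, in which case the first word also gets the leading `'Ġ'`:

# In[ ]:

pre_tokenizers.ByteLevel(add_prefix_space=True).pre_tokenize_str("This is an example!")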
# We can now train our tokenizer! This time we use a `BpeTrainer`.

# In[ ]:

trainer = trainers.BpeTrainer(vocab_size=25000, special_tokens=["<|endoftext|>"])
tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)

# To finish the whole pipeline, we have to include the post-processor and decoder:

# In[ ]:

tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
tokenizer.decoder = decoders.ByteLevel()

# And like before, we finish by wrapping this in a Transformers tokenizer object:

# In[ ]:

from transformers import GPT2TokenizerFast

new_tokenizer = GPT2TokenizerFast(tokenizer_object=tokenizer)

# ### Unigram model like Albert

# Let's now have a look at how we can create a Unigram tokenizer like the one used for training ALBERT. The first step is to create a `Tokenizer` with an empty `Unigram` model:

# In[ ]:

tokenizer = Tokenizer(models.Unigram())

# Like before, we have to add the optional normalization (here a couple of replacements and lowercasing) and we need to specify a pre-tokenizer before training. The pre-tokenizer used is a `Metaspace` pre-tokenizer: it replaces all spaces by a special character (defaulting to ▁) and then splits on that character.

# In[ ]:

tokenizer.normalizer = normalizers.Sequence(
    [normalizers.Replace("``", '"'), normalizers.Replace("''", '"'), normalizers.Lowercase()]
)
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()

# If we want to have a quick look at how it preprocesses the inputs, we can call the `pre_tokenize_str` method:

# In[ ]:

tokenizer.pre_tokenizer.pre_tokenize_str("This is an example!")

# You can see that each word gets an initial `▁` added at the beginning, as is usually done by SentencePiece.
#
# We can now train our tokenizer! This time we use a `UnigramTrainer`. We have to explicitly set the unknown token in this trainer, otherwise it will be forgotten afterward.

# In[ ]:

trainer = trainers.UnigramTrainer(
    vocab_size=25000, special_tokens=["[CLS]", "[SEP]", "<unk>", "<pad>", "[MASK]"], unk_token="<unk>"
)
tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)

# To finish the whole pipeline, we have to include the post-processor and decoder. The post-processor is very similar to what we saw with BERT; the decoder is just `Metaspace`, like for the pre-tokenizer.

# In[ ]:

cls_token_id = tokenizer.token_to_id("[CLS]")
sep_token_id = tokenizer.token_to_id("[SEP]")

# In[ ]:

tokenizer.post_processor = processors.TemplateProcessing(
    single="[CLS]:0 $A:0 [SEP]:0",
    pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
    special_tokens=[
        ("[CLS]", cls_token_id),
        ("[SEP]", sep_token_id),
    ],
)
tokenizer.decoder = decoders.Metaspace()

# And like before, we finish by wrapping this in a Transformers tokenizer object:

# In[ ]:

from transformers import AlbertTokenizerFast

new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)

# ## Use your new tokenizer to train a language model!

# You can either use your new tokenizer in the language modeling from scratch notebook [Link to come] or use the `--tokenizer_name` argument in the [language modeling scripts](https://github.com/huggingface/transformers/tree/master/examples/pytorch/language-modeling) to train a model from scratch with it.
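# For example, assuming you saved one of the tokenizers above locally with `save_pretrained("my-new-tokenizer")` as we did earlier (or pushed it to the Hub under your namespace), your training code can simply reload it with `AutoTokenizer`; just replace the name below with the directory or repo ID you used:

# In[ ]:

from transformers import AutoTokenizer

reloaded_tokenizer = AutoTokenizer.from_pretrained("my-new-tokenizer")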