#!/usr/bin/env python
# coding: utf-8

# If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers and 🤗 Datasets. Uncomment the following cell and run it. We also use the `sacrebleu` and `sentencepiece` libraries - you may need to install these even if you already have 🤗 Transformers!

# In[ ]:

#! pip install transformers[sentencepiece] datasets
#! pip install sacrebleu sentencepiece
#! pip install huggingface_hub

# If you're opening this notebook locally, make sure your environment has the latest version of those libraries installed.
#
# To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.
#
# First you have to store your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!), then uncomment the following cell and input your token:

# In[ ]:

from huggingface_hub import notebook_login

notebook_login()

# Then you need to install Git-LFS and set up Git if you haven't already. Uncomment the following instructions and adapt them with your name and email:

# In[ ]:

# !apt install git-lfs
# !git config --global user.email "you@example.com"
# !git config --global user.name "Your Name"

# Make sure your version of Transformers is at least 4.16.0, since some of the functionality we use was introduced in that version:

# In[1]:

import transformers

print(transformers.__version__)

# You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs [here](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/translation).

# We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

# In[ ]:

from transformers.utils import send_example_telemetry

send_example_telemetry("translation_notebook", framework="tensorflow")

# # Fine-tuning a model on a translation task

# In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models for a translation task. We will use the [WMT dataset](http://www.statmt.org/wmt16/), a machine translation dataset composed of a collection of various sources, including news commentaries and parliament proceedings.
#
# ![Widget inference on a translation task](images/translation.png)
#
# We will see how to easily load the dataset for this task using 🤗 Datasets and how to fine-tune a model on it using Keras.

# In[2]:

model_checkpoint = "Helsinki-NLP/opus-mt-en-ROMANCE"

# This notebook is built to run with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a sequence-to-sequence version in the Transformers library. Here we picked the [`Helsinki-NLP/opus-mt-en-ROMANCE`](https://huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE) checkpoint.

# ## Loading the dataset

# We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the `datasets` function `load_dataset` and the `evaluate` function `load`. We use the English/Romanian part of the WMT dataset here.
# In[3]:

from datasets import load_dataset
from evaluate import load

raw_datasets = load_dataset("wmt16", "ro-en")
metric = load("sacrebleu")

# The `raw_datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key each for the training, validation and test sets:

# In[4]:

raw_datasets

# To access an actual element, you need to select a split first, then give an index:

# In[5]:

raw_datasets["train"][0]

# To get a sense of what the data looks like, the following function will show some examples picked randomly from the dataset.

# In[6]:

import datasets
import random
import pandas as pd
from IPython.display import display, HTML


def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(
        dataset
    ), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset) - 1)
        while pick in picks:
            pick = random.randint(0, len(dataset) - 1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))


# In[7]:

show_random_elements(raw_datasets["train"])

# Since we loaded it with `evaluate.load`, the metric is an instance of `evaluate.EvaluationModule`:

# In[8]:

metric

# You can call its `compute` method with your predictions and labels, which need to be lists of decoded strings (a list of lists for the labels):

# In[9]:

fake_preds = ["hello there", "general kenobi"]
fake_labels = [["hello there"], ["general kenobi"]]
metric.compute(predictions=fake_preds, references=fake_labels)

# ## Preprocessing the data

# Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer`, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put them in a format the model expects, as well as generate the other inputs that the model requires.
#
# To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:
#
# - we get a tokenizer that corresponds to the model architecture we want to use,
# - we download the vocabulary used when pretraining this specific checkpoint.
#
# That vocabulary will be cached, so it's not downloaded again the next time we run the cell.

# In[10]:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# If you are using an mBART tokenizer (the Marian checkpoint we picked above doesn't need this), you need to set the source and target languages so the texts are preprocessed properly. You can check the language codes [here](https://huggingface.co/facebook/mbart-large-cc25) if you are using this notebook on a different pair of languages.

# In[11]:

if "mbart" in model_checkpoint:
    tokenizer.src_lang = "en_XX"
    tokenizer.tgt_lang = "ro_RO"

# By default, the call above will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library.

# You can directly call this tokenizer on one sentence or a pair of sentences:

# In[12]:

tokenizer("Hello, this is a sentence!")

# Depending on the model you selected, you will see different keys in the dictionary returned by the cell above.
# They don't matter much for what we're doing here (just know they are required by the model we will instantiate later); you can learn more about them in [this tutorial](https://huggingface.co/transformers/preprocessing.html) if you're interested.
#
# Instead of one sentence, we can pass along a list of sentences:

# In[13]:

tokenizer(["Hello, this is a sentence!", "This is another sentence."])

# To prepare the targets for our model, we need to tokenize them inside the `as_target_tokenizer` context manager. This will make sure the tokenizer uses the special tokens corresponding to the targets:

# In[14]:

with tokenizer.as_target_tokenizer():
    print(tokenizer(["Hello, this is a sentence!", "This is another sentence."]))

# If you are using one of the five T5 checkpoints, which require a special prefix before the inputs, the following cell sets that prefix up - adapt it if you are working with a different T5 variant or language pair.

# In[15]:

if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "translate English to Romanian: "
else:
    prefix = ""

# We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. This will ensure that an input longer than what the selected model can handle will be truncated to the maximum length accepted by the model. The padding will be dealt with later on (in a data collator), so we pad examples to the longest length in the batch and not the whole dataset.

# In[16]:

max_input_length = 128
max_target_length = 128
source_lang = "en"
target_lang = "ro"


def preprocess_function(examples):
    inputs = [prefix + ex[source_lang] for ex in examples["translation"]]
    targets = [ex[target_lang] for ex in examples["translation"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:

# In[17]:

preprocess_function(raw_datasets["train"][:2])

# To apply this function to all the pairs of sentences in our dataset, we just use the `map` method of the `dataset` object we created earlier. This will apply the function to all the elements of all the splits in `dataset`, so our training, validation and testing data will be preprocessed in one single command.

# In[18]:

tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

# Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to `map` has changed (and thus to not use the cached data). For instance, it will properly detect if you change the task in the first cell and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass `load_from_cache_file=False` in the call to `map` to not use the cached files and force the preprocessing to be applied again.
#
# Note that we passed `batched=True` to encode the texts by batches together. This is to leverage the full benefit of the fast tokenizer we loaded earlier, which will use multi-threading to treat the texts in a batch concurrently.

# ## Fine-tuning the model

# Now that our data is ready, we can download the pretrained model and fine-tune it.
# Since our task is of the sequence-to-sequence kind, we use the `TFAutoModelForSeq2SeqLM` class. Like with the tokenizer, the `from_pretrained` method will download and cache the model for us.

# In[19]:

from transformers import TFAutoModelForSeq2SeqLM, DataCollatorForSeq2Seq

model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# Note that we don't get a warning like in our classification example. This means we used all the weights of the pretrained model and there is no randomly initialized head in this case.

# Next we set some parameters like the learning rate and the `batch_size`, and customize the weight decay.
#
# The last two variables are there to set everything up so we can push the model to the [Hub](https://huggingface.co/models) at the end of training. Remove both of them if you didn't follow the installation steps at the top of the notebook; otherwise you can change the value of `push_to_hub_model_id` to something you would prefer.

# In[20]:

batch_size = 16
learning_rate = 2e-5
weight_decay = 0.01
num_train_epochs = 1

model_name = model_checkpoint.split("/")[-1]
push_to_hub_model_id = f"{model_name}-finetuned-{source_lang}-to-{target_lang}"

# Then, we need a special kind of data collator, which will not only pad the inputs to the maximum length in the batch, but also the labels. Note that our data collators are designed to work for multiple frameworks, so ensure you set the `return_tensors='np'` argument to get NumPy arrays out - you don't want to accidentally get a load of `torch.Tensor` objects in the middle of your nice TF code! You could also use `return_tensors='tf'` to get TensorFlow tensors, but our TF dataset pipeline actually uses a NumPy loader internally, which is wrapped at the end with a `tf.data.Dataset`. As a result, `np` is usually more reliable and performant when you're using it!

# In[1]:

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="np")

generation_data_collator = DataCollatorForSeq2Seq(
    tokenizer, model=model, return_tensors="np", pad_to_multiple_of=128
)

# Next, we convert our datasets to `tf.data.Dataset`, which Keras understands natively. There are two ways to do this - we can use the slightly more low-level [`Dataset.to_tf_dataset()`](https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.to_tf_dataset) method, or we can use [`Model.prepare_tf_dataset()`](https://huggingface.co/docs/transformers/main_classes/model#transformers.TFPreTrainedModel.prepare_tf_dataset). The main difference between these two is that the `Model` method can inspect the model to determine which column names it can use as input, which means you don't need to specify them yourself. Make sure to specify the collator we just created as our `collate_fn`!
#
# We also want to compute `BLEU` metrics, which will require us to generate text from our model. To speed things up, we can compile our generation loop with XLA. This results in a *huge* speedup - up to 100X! The downside of XLA generation, though, is that it doesn't like variable input shapes, because it needs to run a new compilation for each new input shape! To compensate for that, let's use `pad_to_multiple_of` for the dataset we use for text generation. This will reduce the number of unique input shapes a lot, meaning we can get the benefits of XLA generation with only a few compilations.
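# Before we build the datasets with `prepare_tf_dataset()` below, here's a rough sketch of what the lower-level `Dataset.to_tf_dataset()` route mentioned above might look like. The column names here are our own assumption for this seq2seq setup - `prepare_tf_dataset()` normally infers them from the model for you, which is why we use it in the next cell.

# In[ ]:

# Sketch only: a roughly equivalent training dataset built with the lower-level API.
# The columns listed are an assumption; prepare_tf_dataset() works them out automatically.
train_dataset_alt = tokenized_datasets["train"].to_tf_dataset(
    columns=["input_ids", "attention_mask", "labels"],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator,
)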
# In[22]:

train_dataset = model.prepare_tf_dataset(
    tokenized_datasets["train"],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator,
)

validation_dataset = model.prepare_tf_dataset(
    tokenized_datasets["validation"],
    batch_size=batch_size,
    shuffle=False,
    collate_fn=data_collator,
)

generation_dataset = model.prepare_tf_dataset(
    tokenized_datasets["validation"],
    batch_size=8,
    shuffle=False,
    collate_fn=generation_data_collator,
)

# Now we initialize our loss and optimizer and compile the model. Note that most Transformers models compute loss internally, so we can just leave the loss argument blank to use the internal loss instead. For the optimizer, we can use the `AdamWeightDecay` optimizer in the Transformers library.

# In[23]:

from transformers import AdamWeightDecay
import tensorflow as tf

optimizer = AdamWeightDecay(learning_rate=learning_rate, weight_decay_rate=weight_decay)

model.compile(optimizer=optimizer)

# Now we can train our model. We can also add a few optional callbacks here, which you can remove if they aren't useful to you. In no particular order, these are:
# - `PushToHubCallback` will sync up our model with the Hub - this allows us to resume training from other machines, share the model after training is finished, and even test the model's inference quality midway through training!
# - `TensorBoard` is a built-in Keras callback that logs TensorBoard metrics.
# - `KerasMetricCallback` is a callback for computing advanced metrics. There are a number of common metrics in NLP, like ROUGE, which are hard to fit into your compiled training loop because they depend on decoding predictions and labels back to strings with the tokenizer, and calling arbitrary Python functions to compute the metric. The `KerasMetricCallback` will wrap a metric function, outputting metrics as training progresses.
#
# If this is the first time you've seen `KerasMetricCallback`, it's worth explaining what exactly is going on here. The callback takes two main arguments - a `metric_fn` and an `eval_dataset`. It then iterates over the `eval_dataset` and collects the model's outputs for each sample, before passing the `list` of predictions and the associated `list` of labels to the user-defined `metric_fn`. If the `predict_with_generate` argument is `True`, then it will call `model.generate()` for each input sample instead of `model.predict()` - this is useful for metrics that expect generated text from the model, like `ROUGE` and `BLEU`.
#
# This callback allows complex metrics to be computed each epoch that would not function as a standard Keras Metric. Metric values are printed each epoch, and can be used by other callbacks like `TensorBoard` or `EarlyStopping`.
# In[24]:

from transformers.keras_callbacks import KerasMetricCallback
import numpy as np


def metric_fn(eval_predictions):
    preds, labels = eval_predictions
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # We use -100 to mask labels - replace it with the tokenizer pad token when decoding
    # so that no output is emitted for these
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Some simple post-processing
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]
    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}
    result["gen_len"] = np.mean(prediction_lens)
    return result


metric_callback = KerasMetricCallback(
    metric_fn=metric_fn,
    eval_dataset=generation_dataset,
    predict_with_generate=True,
    use_xla_generation=True,
    generate_kwargs={"max_length": 128},
)

# With the metric callback ready, now we can specify the other callbacks and fit our model:

# In[25]:

from transformers.keras_callbacks import PushToHubCallback
from tensorflow.keras.callbacks import TensorBoard

tensorboard_callback = TensorBoard(log_dir="./translation_model_save/logs")

push_to_hub_callback = PushToHubCallback(
    output_dir="./translation_model_save",
    tokenizer=tokenizer,
    hub_model_id=push_to_hub_model_id,
)

callbacks = [metric_callback, tensorboard_callback, push_to_hub_callback]

model.fit(
    train_dataset, validation_data=validation_dataset, epochs=1, callbacks=callbacks
)

# If you used the callback above, you can now share this model with all your friends, family or favorite pets: they can all load it with the identifier `"your-username/the-name-you-picked"`, so for instance:
#
# ```python
# from transformers import TFAutoModelForSeq2SeqLM
#
# model = TFAutoModelForSeq2SeqLM.from_pretrained("your-username/my-awesome-model")
# ```

# ## Inference

# Now that we've trained our model, let's see how we could load it and use it to translate text in the future! First, let's load it from the Hub. This means we can resume the code from here without needing to rerun everything above every time.

# In[1]:

from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# You can of course substitute your own username and model here if you've trained and uploaded it!
model_name = 'Rocketknight1/opus-mt-en-ROMANCE-finetuned-en-to-ro'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

# Now let's try tokenizing some text and passing it to the model to generate a translation. Don't forget to add the "translate English to Romanian: " prefix at the start if you're using a `T5` model.

# In[3]:

input_text = "I'm not actually a very competent Romanian speaker, but let's try our best."
if 't5' in model_name:
    input_text = "translate English to Romanian: " + input_text
tokenized = tokenizer([input_text], return_tensors='np')
out = model.generate(**tokenized, max_length=128)
print(out)

# Well, that's some tokens and a lot of padding! Let's decode those to see what it says, using the `skip_special_tokens` argument to skip those padding tokens:

# In[6]:

with tokenizer.as_target_tokenizer():
    print(tokenizer.decode(out[0], skip_special_tokens=True))

# This is the point where I start wishing I'd done this example in a language I actually speak. Still, it looks good! Probably!
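# Since that seems to work, here's a quick sketch of translating a small batch of sentences in one call. The example sentences are made up for illustration, and we pad the batch so all inputs share one shape:

# In[ ]:

# Sketch: batch translation with the model and tokenizer loaded above.
# (If you're using a T5 checkpoint, add the "translate English to Romanian: " prefix here too.)
example_sentences = [
    "The weather is lovely today.",
    "Machine translation has improved a lot in recent years.",
]
batch = tokenizer(example_sentences, return_tensors="np", padding=True)
batch_out = model.generate(**batch, max_length=128)
with tokenizer.as_target_tokenizer():
    for seq in batch_out:
        print(tokenizer.decode(seq, skip_special_tokens=True))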
# ## Using XLA in inference

# If you just want to generate a few translations, the code above is all you need. However, generation can be **much** faster if you use XLA, and if you want to generate data in bulk, you should probably use it! If you're using XLA, though, remember that you'll need to do a new XLA compilation for every input size you pass to the model. This means that you should keep your batch size constant, and consider padding inputs to the same length, or using `pad_to_multiple_of` in your tokenizer to reduce the number of different input shapes you pass. Let's show an example of that:

# In[9]:

import tensorflow as tf


@tf.function(jit_compile=True)
def generate(inputs):
    return model.generate(**inputs, max_length=128)


tokenized_data = tokenizer([input_text], return_tensors="np", pad_to_multiple_of=128)
out = generate(tokenized_data)

# In[8]:

with tokenizer.as_target_tokenizer():
    print(tokenizer.decode(out[0], skip_special_tokens=True))

# ## Pipeline API

# The pipeline API offers a convenient shortcut for all of this, but doesn't (yet!) support XLA generation:

# In[12]:

from transformers import pipeline

translator = pipeline('text2text-generation', model_name, framework="tf")

# In[13]:

translator(input_text, max_length=128)

# Easy!
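# As a final note, the pipeline also accepts a list of inputs if you want to translate several sentences in one call - a quick sketch (the sentences are again made up):

# In[ ]:

# Sketch: passing a batch of sentences to the pipeline.
translator(
    ["This is the first sentence.", "And this is the second one."],
    max_length=128,
)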