If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers and 🤗 Datasets. Uncomment the following cell and run it. We also use the sacrebleu and sentencepiece libraries - you may need to install these even if you already have 🤗 Transformers!
#! pip install transformers[sentencepiece] datasets
#! pip install sacrebleu sentencepiece
#! pip install huggingface_hub
If you're opening this notebook locally, make sure your environment has the latest version of those libraries installed.
To be able to share your model with the community and generate results via the inference API, there are a few more steps to follow.
First you have to store your authentication token from the Hugging Face website (sign up here if you haven't already!), then uncomment the following cell and input your token:
from huggingface_hub import notebook_login
notebook_login()
Then you need to install Git-LFS and set up Git if you haven't already. Uncomment the following instructions and adapt them with your name and email:
# !apt install git-lfs
# !git config --global user.email "you@example.com"
# !git config --global user.name "Your Name"
Make sure your version of Transformers is at least 4.16.0 since some of the functionality we use was introduced in that version:
import transformers
print(transformers.__version__)
4.21.0.dev0
You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs here.
We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.
from transformers.utils import send_example_telemetry
send_example_telemetry("translation_notebook", framework="tensorflow")
In this notebook, we will see how to fine-tune one of the 🤗 Transformers models for a translation task. We will use the WMT dataset, a machine translation dataset composed of a collection of various sources, including news commentaries and parliament proceedings.
We will see how to easily load the dataset for this task using 🤗 Datasets and how to fine-tune a model on it using Keras.
model_checkpoint = "Helsinki-NLP/opus-mt-en-ROMANCE"
This notebook is built to run with any model checkpoint from the Model Hub as long as that model has a sequence-to-sequence version in the Transformers library. Here we picked the Helsinki-NLP/opus-mt-en-ROMANCE checkpoint.
We will use the 🤗 Datasets library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the load_dataset function from the datasets library and the load function from the evaluate library. We use the English/Romanian part of the WMT dataset here.
from datasets import load_dataset
from evaluate import load
raw_datasets = load_dataset("wmt16", "ro-en")
metric = load("sacrebleu")
Reusing dataset wmt16 (/home/matt/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/28ebdf8cf22106c2f1e58b2083d4b103608acd7bfdb6b14313ccd9e5bc8c313a)
0%| | 0/3 [00:00<?, ?it/s]
The dataset object itself is a DatasetDict, which contains one key each for the training, validation and test sets:
raw_datasets
DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 610320
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 1999
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 1999
    })
})
To access an actual element, you need to select a split first, then give an index:
raw_datasets["train"][0]
{'translation': {'en': 'Membership of Parliament: see Minutes', 'ro': 'Componenţa Parlamentului: a se vedea procesul-verbal'}}
To get a sense of what the data looks like, the following function will show some examples picked randomly from the dataset.
import datasets
import random
import pandas as pd
from IPython.display import display, HTML
def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(
        dataset
    ), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset) - 1)
        while pick in picks:
            pick = random.randint(0, len(dataset) - 1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
show_random_elements(raw_datasets["train"])
 | translation |
---|---|
0 | {'en': '"Kosovo does not have not a positive image mainly because (the media portrays) the Serbs living in ghettoes ... and NATO helping Albanians displace the Serbs from Kosovo," Chukov said.', 'ro': '"Kosovo nu are o imagine pozitivă mai ales din cauza faptului că (presa arată) că sârbii trăiesc în ghetto-uri ... şi NATO îi ajută pe albanezi să strămute sârbii din Kosovo", a spus Chukov.'} |
1 | {'en': 'They also signed a memorandum of understanding on diplomatic consultations.', 'ro': 'Aceştia au semnat de asemenea un protocol de acord cu privire la consultaţiile diplomatice.'} |
2 | {'en': 'EU Commissioner for Home Affairs Cecilia Malmstrom said on Monday (September 20th) that Albania has made significant progress in meeting requirements for visa-free travel.', 'ro': 'Comisarul UE pentru afaceri interne, Cecilia Malmstrom, a declarat luni (20 septembrie) că Albania a făcut progrese semnificative în întrunirea condiţiilor pentru liberalizarea vizelor.'} |
3 | {'en': '13.', 'ro': '13.'} |
4 | {'en': 'But, in principle, thank you very much for what was, for me, too, a very interesting debate, and all the best.', 'ro': 'Dar, în principiu, vă mulţumesc foarte mult pentru această dezbatere care a fost foarte interesantă pentru mine şi vă urez toate cele bune.'} |
The metric is an instance of evaluate.EvaluationModule:
metric
EvaluationModule(name: "sacrebleu", module_type: "metric", features: [{'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id='references')}, {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}], usage: """ Produces BLEU scores along with its sufficient statistics from a source against one or more references. Args: predictions (`list` of `str`): list of translations to score. Each translation should be tokenized into a list of tokens. references (`list` of `list` of `str`): A list of lists of references. The contents of the first sub-list are the references for the first prediction, the contents of the second sub-list are for the second prediction, etc. Note that there must be the same number of references for each prediction (i.e. all sub-lists must be of the same length). smooth_method (`str`): The smoothing method to use, defaults to `'exp'`. Possible values are: - `'none'`: no smoothing - `'floor'`: increment zero counts - `'add-k'`: increment num/denom by k for n>1 - `'exp'`: exponential decay smooth_value (`float`): The smoothing value. Only valid when `smooth_method='floor'` (in which case `smooth_value` defaults to `0.1`) or `smooth_method='add-k'` (in which case `smooth_value` defaults to `1`). tokenize (`str`): Tokenization method to use for BLEU. If not provided, defaults to `'zh'` for Chinese, `'ja-mecab'` for Japanese and `'13a'` (mteval) otherwise. Possible values are: - `'none'`: No tokenization. - `'zh'`: Chinese tokenization. - `'13a'`: mimics the `mteval-v13a` script from Moses. - `'intl'`: International tokenization, mimics the `mteval-v14` script from Moses - `'char'`: Language-agnostic character-level tokenization. - `'ja-mecab'`: Japanese tokenization. Uses the [MeCab tokenizer](https://pypi.org/project/mecab-python3). lowercase (`bool`): If `True`, lowercases the input, enabling case-insensitivity. Defaults to `False`. force (`bool`): If `True`, insists that your tokenized input is actually detokenized. Defaults to `False`. use_effective_order (`bool`): If `True`, stops including n-gram orders for which precision is 0. This should be `True`, if sentence-level BLEU will be computed. Defaults to `False`. Returns: 'score': BLEU score, 'counts': Counts, 'totals': Totals, 'precisions': Precisions, 'bp': Brevity penalty, 'sys_len': predictions length, 'ref_len': reference length, Examples: Example 1: >>> predictions = ["hello there general kenobi", "foo bar foobar"] >>> references = [["hello there general kenobi", "hello there !"], ["foo bar foobar", "foo bar foobar"]] >>> sacrebleu = evaluate.load("sacrebleu") >>> results = sacrebleu.compute(predictions=predictions, references=references) >>> print(list(results.keys())) ['score', 'counts', 'totals', 'precisions', 'bp', 'sys_len', 'ref_len'] >>> print(round(results["score"], 1)) 100.0 Example 2: >>> predictions = ["hello there general kenobi", ... "on our way to ankh morpork"] >>> references = [["hello there general kenobi", "hello there !"], ... ["goodbye ankh morpork", "ankh morpork"]] >>> sacrebleu = evaluate.load("sacrebleu") >>> results = sacrebleu.compute(predictions=predictions, ... references=references) >>> print(list(results.keys())) ['score', 'counts', 'totals', 'precisions', 'bp', 'sys_len', 'ref_len'] >>> print(round(results["score"], 1)) 39.8 """, stored examples: 0)
You can call its compute method with your predictions and labels, which need to be lists of decoded strings (a list of lists of strings for the labels):
fake_preds = ["hello there", "general kenobi"]
fake_labels = [["hello there"], ["general kenobi"]]
metric.compute(predictions=fake_preds, references=fake_labels)
{'score': 0.0, 'counts': [4, 2, 0, 0], 'totals': [4, 2, 0, 0], 'precisions': [100.0, 100.0, 0.0, 0.0], 'bp': 1.0, 'sys_len': 4, 'ref_len': 4}
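As the usage docstring above shows, sacrebleu also accepts several references per prediction, as long as every prediction gets the same number of references. A minimal sketch (the sentences here are made up purely for illustration, and the resulting score will depend on them):
# Illustrative only: one prediction scored against two references
preds = ["the cat sat on the mat"]
refs = [["the cat sat on the mat", "a cat was sitting on the mat"]]
metric.compute(predictions=preds, references=refs)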
Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers Tokenizer, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put them in a format the model expects, as well as generate the other inputs the model requires.
To do all of this, we instantiate our tokenizer with the AutoTokenizer.from_pretrained method, which will ensure that we get a tokenizer corresponding to the model architecture we want to use, and that we download the vocabulary used when pretraining this specific checkpoint. That vocabulary will be cached, so it's not downloaded again the next time we run the cell.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
For mBART tokenizers (not the case here, since this checkpoint uses a Marian tokenizer), we would also need to set the source and target languages so the texts are preprocessed properly. You can check the language codes here if you are using this notebook on a different pair of languages.
if "mbart" in model_checkpoint:
tokenizer.src_lang = "en-XX"
tokenizer.tgt_lang = "ro-RO"
By default, the call above will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library.
You can directly call this tokenizer on one sentence or a pair of sentences:
tokenizer("Hello, this is a sentence!")
{'input_ids': [4708, 2, 69, 28, 9, 8662, 84, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
Depending on the model you selected, you will see different keys in the dictionary returned by the cell above. They don't matter much for what we're doing here (just know they are required by the model we will instantiate later), you can learn more about them in this tutorial if you're interested.
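If you're curious what those input IDs correspond to, you can map them back to subword tokens - a quick illustrative sketch that isn't needed for the rest of the notebook:
# Purely illustrative: inspect the subword tokens behind the input IDs
encoded = tokenizer("Hello, this is a sentence!")
print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))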
Instead of one sentence, we can pass along a list of sentences:
tokenizer(["Hello, this is a sentence!", "This is another sentence."])
{'input_ids': [[4708, 2, 69, 28, 9, 8662, 84, 0], [188, 28, 823, 8662, 3, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}
To prepare the targets for our model, we need to tokenize them inside the as_target_tokenizer context manager. This will make sure the tokenizer uses the special tokens corresponding to the targets:
with tokenizer.as_target_tokenizer():
    print(tokenizer(["Hello, this is a sentence!", "This is another sentence."]))
{'input_ids': [[14232, 244, 2, 69, 160, 6, 9, 10513, 1101, 84, 0], [13486, 6, 160, 6, 3778, 4853, 10513, 1101, 3, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
If you are using one of the five T5 checkpoints, which require a special prefix before the inputs, you should adapt the following cell.
if model_checkpoint in ["t5-small", "t5-base", "t5-larg", "t5-3b", "t5-11b"]:
prefix = "translate English to Romanian: "
else:
prefix = ""
We can then write the function that will preprocess our samples. We just feed them to the tokenizer with the argument truncation=True. This will ensure that an input longer than what the selected model can handle will be truncated to the maximum length accepted by the model. The padding will be dealt with later on (in a data collator), so we pad examples to the longest length in the batch rather than the whole dataset.
max_input_length = 128
max_target_length = 128
source_lang = "en"
target_lang = "ro"
def preprocess_function(examples):
    inputs = [prefix + ex[source_lang] for ex in examples["translation"]]
    targets = [ex[target_lang] for ex in examples["translation"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:
preprocess_function(raw_datasets["train"][:2])
{'input_ids': [[37284, 8, 949, 37, 358, 31483, 0], [32818, 8, 31483, 8, 2541, 7910, 37, 358, 31483, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[1163, 8008, 7037, 26971, 37, 9, 56, 16836, 9026, 226, 15, 33834, 0], [67, 16852, 791, 9026, 896, 15, 33834, 111, 10795, 9351, 26549, 11114, 37, 9, 56, 16836, 9026, 226, 15, 33834, 0]]}
To apply this function to all the pairs of sentences in our dataset, we just use the map method of the dataset object we created earlier. This will apply the function to all the elements of all the splits in the dataset, so our training, validation and testing data will be preprocessed in one single command.
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)
Loading cached processed dataset at /home/matt/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/28ebdf8cf22106c2f1e58b2083d4b103608acd7bfdb6b14313ccd9e5bc8c313a/cache-f1b4cc7f6a817a09.arrow Loading cached processed dataset at /home/matt/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/28ebdf8cf22106c2f1e58b2083d4b103608acd7bfdb6b14313ccd9e5bc8c313a/cache-2dcbdf92c911af2a.arrow Loading cached processed dataset at /home/matt/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/28ebdf8cf22106c2f1e58b2083d4b103608acd7bfdb6b14313ccd9e5bc8c313a/cache-34490b3ad1e70b86.arrow
Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to map has changed (and thus requires not reusing the cached data). For instance, it will properly detect if you change the task in the first cell and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass load_from_cache_file=False in the call to map to not use the cached files and force the preprocessing to be applied again.
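For example, if you ever want to ignore the cache and force the preprocessing to run again (optionally in several processes), a call like the one below should work. This is just an illustrative sketch - load_from_cache_file and num_proc are standard arguments of map, and the value of num_proc here is arbitrary:
# Optional: force re-preprocessing instead of reusing the cache; num_proc is illustrative
tokenized_datasets = raw_datasets.map(
    preprocess_function, batched=True, load_from_cache_file=False, num_proc=4
)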
Note that we passed batched=True to encode the texts in batches. This lets us leverage the full benefit of the fast tokenizer we loaded earlier, which will use multi-threading to process the texts in a batch concurrently.
Now that our data is ready, we can download the pretrained model and fine-tune it. Since our task is of the sequence-to-sequence kind, we use the TFAutoModelForSeq2SeqLM class. Like with the tokenizer, the from_pretrained method will download and cache the model for us.
from transformers import TFAutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
2022-07-25 17:49:51.571462: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-25 17:49:51.577820: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-25 17:49:51.578841: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-25 17:49:51.580434: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. 2022-07-25 17:49:51.583246: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-25 17:49:51.583929: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-25 17:49:51.584582: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-25 17:49:51.938374: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-25 17:49:51.939080: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-25 17:49:51.939739: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-25 17:49:51.940364: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 21659 MB memory: -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:21:00.0, compute capability: 8.6 2022-07-25 17:49:53.116600: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once. All model checkpoint layers were used when initializing TFMarianMTModel. All the layers of TFMarianMTModel were initialized from the model checkpoint at Helsinki-NLP/opus-mt-en-ROMANCE. If your task is similar to the task the model of the checkpoint was trained on, you can already use TFMarianMTModel for predictions without further training.
Note that we don't get a warning like in our classification example. This means we used all the weights of the pretrained model and there is no randomly initialized head in this case.
Next we set some parameters like the learning rate and the batch_size, and customize the weight decay. The last variable, push_to_hub_model_id, sets everything up so we can push the model to the Hub at the end of training. Remove it if you didn't follow the installation steps at the top of the notebook; otherwise you can change its value to something you would prefer.
batch_size = 16
learning_rate = 2e-5
weight_decay = 0.01
num_train_epochs = 1
model_name = model_checkpoint.split("/")[-1]
push_to_hub_model_id = f"{model_name}-finetuned-{source_lang}-to-{target_lang}"
Then, we need a special kind of data collator, which will not only pad the inputs to the maximum length in the batch, but also the labels. Note that our data collators are designed to work for multiple frameworks, so ensure you set the return_tensors='np' argument to get NumPy arrays out - you don't want to accidentally get a load of torch.Tensor objects in the middle of your nice TF code! You could also use return_tensors='tf' to get TensorFlow tensors, but our TF dataset pipeline actually uses a NumPy loader internally, which is wrapped at the end with a tf.data.Dataset. As a result, np is usually more reliable and performant.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="np")
generation_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="np", pad_to_multiple_of=128)
Next, we convert our datasets to tf.data.Dataset, which Keras understands natively. There are two ways to do this - we can use the slightly more low-level Dataset.to_tf_dataset() method, or we can use Model.prepare_tf_dataset(). The main difference between these two is that the Model method can inspect the model to determine which column names it can use as input, which means you don't need to specify them yourself. Make sure to specify the collator we just created as our collate_fn!
We also want to compute BLEU metrics, which will require us to generate text from our model. To speed things up, we can compile our generation loop with XLA. This results in a huge speedup - up to 100X! The downside of XLA generation, though, is that it doesn't like variable input shapes, because it needs to run a new compilation for each new input shape! To compensate for that, let's use pad_to_multiple_of for the dataset we use for text generation. This will reduce the number of unique input shapes a lot, meaning we can get the benefits of XLA generation with only a few compilations.
train_dataset = model.prepare_tf_dataset(
tokenized_datasets["train"],
batch_size=batch_size,
shuffle=True,
collate_fn=data_collator,
)
validation_dataset = model.prepare_tf_dataset(
tokenized_datasets["validation"],
batch_size=batch_size,
shuffle=False,
collate_fn=data_collator,
)
generation_dataset = model.prepare_tf_dataset(
tokenized_datasets["validation"],
batch_size=8,
shuffle=False,
collate_fn=generation_data_collator,
)
Now we initialize our optimizer and compile the model. Note that most Transformers models compute loss internally, so we can just leave the loss argument blank to use the internal loss instead. For the optimizer, we can use the AdamWeightDecay optimizer from the Transformers library.
from transformers import AdamWeightDecay
import tensorflow as tf
optimizer = AdamWeightDecay(learning_rate=learning_rate, weight_decay_rate=weight_decay)
model.compile(optimizer=optimizer)
No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.
Now we can train our model. We can also add a few optional callbacks here, which you can remove if they aren't useful to you. In no particular order, these are a KerasMetricCallback to compute BLEU during validation, a TensorBoard callback to log training progress, and a PushToHubCallback to upload the model to the Hub.
If this is the first time you've seen KerasMetricCallback, it's worth explaining what exactly is going on here. The callback takes two main arguments - a metric_fn and an eval_dataset. It then iterates over the eval_dataset and collects the model's outputs for each sample, before passing the list of predictions and the associated list of labels to the user-defined metric_fn. If the predict_with_generate argument is True, then it will call model.generate() for each input sample instead of model.predict() - this is useful for metrics that expect generated text from the model, like ROUGE and BLEU.
This callback allows complex metrics to be computed each epoch that would not function as a standard Keras Metric. Metric values are printed each epoch, and can be used by other callbacks like TensorBoard or EarlyStopping.
from transformers.keras_callbacks import KerasMetricCallback
import numpy as np
def metric_fn(eval_predictions):
    preds, labels = eval_predictions
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # We use -100 to mask labels - replace it with the tokenizer pad token when decoding
    # so that no output is emitted for these
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Some simple post-processing
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]
    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}
    result["gen_len"] = np.mean(prediction_lens)
    return result
metric_callback = KerasMetricCallback(
metric_fn=metric_fn, eval_dataset=generation_dataset, predict_with_generate=True, use_xla_generation=True,
generate_kwargs={"max_length": 128}
)
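As noted above, the metric values computed by KerasMetricCallback are added to the Keras logs, so other callbacks can react to them. As a hedged sketch (not used in the rest of this notebook, and assuming the BLEU score appears in the logs under the key "bleu", matching the dictionary returned by metric_fn above), you could add early stopping on the BLEU score like this:
from tensorflow.keras.callbacks import EarlyStopping

# Hypothetical: stop training if BLEU hasn't improved for two epochs
# (assumes KerasMetricCallback logs the score under the key "bleu")
early_stopping_callback = EarlyStopping(monitor="bleu", mode="max", patience=2)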
With the metric callback ready, now we can specify the other callbacks and fit our model:
from transformers.keras_callbacks import PushToHubCallback
from tensorflow.keras.callbacks import TensorBoard
tensorboard_callback = TensorBoard(log_dir="./translation_model_save/logs")
push_to_hub_callback = PushToHubCallback(
output_dir="./translation_model_save",
tokenizer=tokenizer,
hub_model_id=push_to_hub_model_id,
)
callbacks = [metric_callback, tensorboard_callback, push_to_hub_callback]
model.fit(
train_dataset, validation_data=validation_dataset, epochs=1, callbacks=callbacks
)
/home/matt/PycharmProjects/notebooks/examples/translation_model_save is already a clone of https://huggingface.co/Rocketknight1/opus-mt-en-ROMANCE-finetuned-en-to-ro. Make sure you pull the latest changes with `repo.git_pull()`.
6/38145 [..............................] - ETA: 56:25 - loss: 5.2187WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0788s vs `on_train_batch_end` time: 0.1046s). Check your callbacks. 38145/38145 [==============================] - ETA: 0s - loss: 0.7140
2022-07-25 18:43:16.811498: I tensorflow/compiler/xla/service/service.cc:170] XLA service 0x5633dc97b3a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices: 2022-07-25 18:43:16.811529: I tensorflow/compiler/xla/service/service.cc:178] StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6 2022-07-25 18:43:16.943241: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:263] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable. 2022-07-25 18:43:17.816234: I tensorflow/compiler/xla/service/dynamic_dimension_inference.cc:965] Reshaping a dynamic dimension into a scalar, which has undefined behavior when input size is 0. The offending instruction is: %reshape.41 = s32[] reshape(s32[<=1]{0} %set-dimension-size.3), metadata={op_type="Equal" op_name="cond/while/map/while/map/while/cond/cond/Equal" source_file="/home/matt/PycharmProjects/transformers/src/transformers/generation_tf_logits_process.py" source_line=351} 2022-07-25 18:43:25.655895: I tensorflow/compiler/jit/xla_compilation_cache.cc:478] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 2022-07-25 18:45:51.416864: I tensorflow/compiler/xla/service/dynamic_dimension_inference.cc:965] Reshaping a dynamic dimension into a scalar, which has undefined behavior when input size is 0. The offending instruction is: %reshape.41 = s32[] reshape(s32[<=1]{0} %set-dimension-size.3), metadata={op_type="Equal" op_name="cond/while/map/while/map/while/cond/cond/Equal" source_file="/home/matt/PycharmProjects/transformers/src/transformers/generation_tf_logits_process.py" source_line=351} Several commits (2) will be pushed upstream.
38145/38145 [==============================] - 3382s 88ms/step - loss: 0.7140 - val_loss: 1.2757 - bleu: 26.7914 - gen_len: 41.4932
<keras.callbacks.History at 0x7f4af02c52d0>
If you used the callback above, you can now share this model with all your friends, family or favorite pets: they can all load it with the identifier "your-username/the-name-you-picked", so for instance:
from transformers import TFAutoModelForSeq2SeqLM
model = TFAutoModelForSeq2SeqLM.from_pretrained("your-username/my-awesome-model")
Now that we've trained our model, let's see how we can load it and use it to translate text in the future! First, let's load it from the Hub. This means we can resume the code from here without needing to rerun everything above every time.
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
# You can of course substitute your own username and model here if you've trained and uploaded it!
model_name = 'Rocketknight1/opus-mt-en-ROMANCE-finetuned-en-to-ro'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
Downloading tokenizer_config.json: 0%| | 0.00/551 [00:00<?, ?B/s]
Downloading source.spm: 0%| | 0.00/761k [00:00<?, ?B/s]
Downloading target.spm: 0%| | 0.00/780k [00:00<?, ?B/s]
Downloading vocab.json: 0%| | 0.00/1.51M [00:00<?, ?B/s]
Downloading special_tokens_map.json: 0%| | 0.00/74.0 [00:00<?, ?B/s]
Downloading config.json: 0%| | 0.00/1.47k [00:00<?, ?B/s]
Downloading tf_model.h5: 0%| | 0.00/298M [00:00<?, ?B/s]
2022-07-26 17:56:38.238360: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-26 17:56:38.275342: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-26 17:56:38.276357: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-26 17:56:38.278209: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. 2022-07-26 17:56:38.309314: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-26 17:56:38.310572: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-26 17:56:38.311790: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-26 17:56:39.033228: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-26 17:56:39.033908: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-26 17:56:39.034535: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2022-07-26 17:56:39.035152: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 21719 MB memory: -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:21:00.0, compute capability: 8.6 2022-07-26 17:56:40.631257: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once. All model checkpoint layers were used when initializing TFMarianMTModel. All the layers of TFMarianMTModel were initialized from the model checkpoint at Rocketknight1/opus-mt-en-ROMANCE-finetuned-en-to-ro. If your task is similar to the task the model of the checkpoint was trained on, you can already use TFMarianMTModel for predictions without further training.
Now let's try tokenizing some text and passing it to the model to generate a translation. Don't forget to add the "translate English to Romanian: " prefix at the start if you're using a T5 model.
input_text = "I'm not actually a very competent Romanian speaker, but let's try our best."
if 't5' in model_name:
    input_text = "translate English to Romanian: " + input_text
tokenized = tokenizer([input_text], return_tensors='np')
out = model.generate(**tokenized, max_length=128)
print(out)
tf.Tensor( [[65000 642 1204 5 12648 35 26792 415 36773 5031 11008 208 2 1019 203 2836 600 229 15032 3796 13286 226 3 0 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000 65000]], shape=(1, 128), dtype=int32)
Well, that's some tokens and a lot of padding! Let's decode those to see what it says, using the skip_special_tokens argument to skip those padding tokens:
with tokenizer.as_target_tokenizer():
    print(tokenizer.decode(out[0], skip_special_tokens=True))
Nu sunt de fapt un vorbitor român foarte competent, dar haideţi să facem tot posibilul.
This is the point where I start wishing I'd done this example in a language I actually speak. Still, it looks good! Probably!
If you just want to generate a few translations, the code above is all you need. However, generation can be much faster if you use XLA, and if you want to generate data in bulk, you should probably use it! If you're using XLA, though, remember that you'll need to do a new XLA compilation for every input size you pass to the model. This means that you should keep your batch size constant, and consider padding inputs to the same length, or using pad_to_multiple_of in your tokenizer to reduce the number of different input shapes you pass. Let's show an example of that:
import tensorflow as tf
@tf.function(jit_compile=True)
def generate(inputs):
    return model.generate(**inputs, max_length=128)
tokenized_data = tokenizer([input_text], return_tensors="np", pad_to_multiple_of=128)
out = generate(tokenized_data)
2022-07-26 18:19:44.757209: I tensorflow/compiler/xla/service/dynamic_dimension_inference.cc:965] Reshaping a dynamic dimension into a scalar, which has undefined behavior when input size is 0. The offending instruction is: %reshape.41 = s32[] reshape(s32[<=1]{0} %set-dimension-size.3), metadata={op_type="Equal" op_name="cond/while/map/while/map/while/cond/cond/Equal" source_file="/home/matt/PycharmProjects/transformers/src/transformers/generation_tf_logits_process.py" source_line=351}
with tokenizer.as_target_tokenizer():
    print(tokenizer.decode(out[0], skip_special_tokens=True))
Nu sunt de fapt un vorbitor român foarte competent, dar haideţi să facem tot posibilul.
The pipeline API offers a convenient shortcut for all of this, but doesn't (yet!) support XLA generation:
from transformers import pipeline
translator = pipeline('text2text-generation', model_name, framework="tf")
All model checkpoint layers were used when initializing TFMarianMTModel. All the layers of TFMarianMTModel were initialized from the model checkpoint at Rocketknight1/opus-mt-en-ROMANCE-finetuned-en-to-ro. If your task is similar to the task the model of the checkpoint was trained on, you can already use TFMarianMTModel for predictions without further training.
translator(input_text, max_length=128)
[{'generated_text': 'Nu sunt de fapt un vorbitor român foarte competent, dar haideţi să facem tot posibilul.'}]
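Pipelines also accept a list of inputs, so if you want to translate several sentences at once you can pass them all in one call. A minimal sketch - the sentences and the batch size below are just an illustration:
# Illustrative: translate a batch of sentences in one call
translator(
    ["I like to eat apples.", "The weather is nice today."],
    max_length=128,
    batch_size=8,
)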
Easy!