Install the Transformers, Datasets, and Evaluate libraries to run this notebook.
!pip install datasets evaluate transformers[sentencepiece]
!apt install git-lfs
You will need to set up git: adapt your email and name in the following cell.
!git config --global user.email "you@example.com"
!git config --global user.name "Your Name"
You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials.
from huggingface_hub import notebook_login
notebook_login()
from datasets import load_dataset
raw_datasets = load_dataset("kde4", lang1="en", lang2="fr")
raw_datasets
DatasetDict({ train: Dataset({ features: ['id', 'translation'], num_rows: 210173 }) })
split_datasets = raw_datasets["train"].train_test_split(train_size=0.9, seed=20)
split_datasets
DatasetDict({ train: Dataset({ features: ['id', 'translation'], num_rows: 189155 }) test: Dataset({ features: ['id', 'translation'], num_rows: 21018 }) })
split_datasets["validation"] = split_datasets.pop("test")
split_datasets["train"][1]["translation"]
{'en': 'Default to expanded threads', 'fr': 'Par défaut, développer les fils de discussion'}
from transformers import pipeline
model_checkpoint = "Helsinki-NLP/opus-mt-en-fr"
translator = pipeline("translation", model=model_checkpoint)
translator("Default to expanded threads")
[{'translation_text': 'Par défaut pour les threads élargis'}]
split_datasets["train"][172]["translation"]
{'en': 'Unable to import %1 using the OFX importer plugin. This file is not the correct format.', 'fr': "Impossible d'importer %1 en utilisant le module d'extension d'importation OFX. Ce fichier n'a pas un format correct."}
translator(
"Unable to import %1 using the OFX importer plugin. This file is not the correct format."
)
[{'translation_text': "Impossible d'importer %1 en utilisant le plugin d'importateur OFX. Ce fichier n'est pas le bon format."}]
from transformers import AutoTokenizer

model_checkpoint = "Helsinki-NLP/opus-mt-en-fr"
# `return_tensors` is an argument of the tokenization call, not a loading option.
# Passing it to from_pretrained() did not change the encodings (the sample output
# is plain Python lists), and "pt" was misleading in this TensorFlow notebook,
# so it is no longer passed here.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Tokenize one English/French pair from the training split. The English sentence
# is the main input; the French sentence is passed as `text_target`, and its ids
# come back under the "labels" key.
en_sentence = split_datasets["train"][1]["translation"]["en"]
fr_sentence = split_datasets["train"][1]["translation"]["fr"]
inputs = tokenizer(en_sentence, text_target=fr_sentence)
inputs
{'input_ids': [47591, 12, 9842, 19634, 9, 0], 'attention_mask': [1, 1, 1, 1, 1, 1], 'labels': [577, 5891, 2, 3184, 16, 2542, 5, 1710, 0]}
wrong_targets = tokenizer(fr_sentence)
print(tokenizer.convert_ids_to_tokens(wrong_targets["input_ids"]))
print(tokenizer.convert_ids_to_tokens(inputs["labels"]))
['▁Par', '▁dé', 'f', 'aut', ',', '▁dé', 've', 'lop', 'per', '▁les', '▁fil', 's', '▁de', '▁discussion', '</s>'] ['▁Par', '▁défaut', ',', '▁développer', '▁les', '▁fils', '▁de', '▁discussion', '</s>']
max_length = 128
def preprocess_function(examples):
    """Tokenize a batch of English/French translation pairs.

    `examples["translation"]` is iterated as a sequence of dicts holding an
    "en" and a "fr" entry. The English texts are tokenized as the inputs and
    the French texts are passed as `text_target`. Both sides are truncated to
    the module-level `max_length`.

    Returns whatever the tokenizer call returns for the batch.
    """
    pairs = examples["translation"]
    source_texts = [pair["en"] for pair in pairs]
    target_texts = [pair["fr"] for pair in pairs]
    return tokenizer(
        source_texts,
        text_target=target_texts,
        max_length=max_length,
        truncation=True,
    )
# Run the preprocessing over every split in batches. The pre-existing columns
# (as listed on the train split) are removed from the result.
tokenized_datasets = split_datasets.map(
    preprocess_function,
    remove_columns=split_datasets["train"].column_names,
    batched=True,
)
from transformers import TFAutoModelForSeq2SeqLM
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, from_pt=True)
from transformers import DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="tf")
batch = data_collator([tokenized_datasets["train"][i] for i in range(1, 3)])
batch.keys()
dict_keys(['attention_mask', 'input_ids', 'labels', 'decoder_input_ids'])
batch["labels"]
tensor([[ 577, 5891, 2, 3184, 16, 2542, 5, 1710, 0, -100, -100, -100, -100, -100, -100, -100], [ 1211, 3, 49, 9409, 1211, 3, 29140, 817, 3124, 817, 550, 7032, 5821, 7907, 12649, 0]])
batch["decoder_input_ids"]
tensor([[59513, 577, 5891, 2, 3184, 16, 2542, 5, 1710, 0, 59513, 59513, 59513, 59513, 59513, 59513], [59513, 1211, 3, 49, 9409, 1211, 3, 29140, 817, 3124, 817, 550, 7032, 5821, 7907, 12649]])
for i in range(1, 3):
print(tokenized_datasets["train"][i]["labels"])
[577, 5891, 2, 3184, 16, 2542, 5, 1710, 0] [1211, 3, 49, 9409, 1211, 3, 29140, 817, 3124, 817, 550, 7032, 5821, 7907, 12649, 0]
# Training data: shuffled, batches of 32, collated with `data_collator`.
tf_train_dataset = model.prepare_tf_dataset(
    tokenized_datasets["train"],
    batch_size=32,
    shuffle=True,
    collate_fn=data_collator,
)

# Validation data: original order, batches of 16, same collator.
tf_eval_dataset = model.prepare_tf_dataset(
    tokenized_datasets["validation"],
    batch_size=16,
    shuffle=False,
    collate_fn=data_collator,
)
!pip install sacrebleu
import evaluate
metric = evaluate.load("sacrebleu")
predictions = [
"This plugin lets you translate web pages between several languages automatically."
]
references = [
[
"This plugin allows you to automatically translate web pages between several languages."
]
]
metric.compute(predictions=predictions, references=references)
{'score': 46.750469682990165, 'counts': [11, 6, 4, 3], 'totals': [12, 11, 10, 9], 'precisions': [91.67, 54.54, 40.0, 33.33], 'bp': 0.9200444146293233, 'sys_len': 12, 'ref_len': 13}
predictions = ["This This This This"]
references = [
[
"This plugin allows you to automatically translate web pages between several languages."
]
]
metric.compute(predictions=predictions, references=references)
{'score': 1.683602693167689, 'counts': [1, 0, 0, 0], 'totals': [4, 3, 2, 1], 'precisions': [25.0, 16.67, 12.5, 12.5], 'bp': 0.10539922456186433, 'sys_len': 4, 'ref_len': 13}
predictions = ["This plugin"]
references = [
[
"This plugin allows you to automatically translate web pages between several languages."
]
]
metric.compute(predictions=predictions, references=references)
{'score': 0.0, 'counts': [2, 1, 0, 0], 'totals': [2, 1, 0, 0], 'precisions': [100.0, 100.0, 0.0, 0.0], 'bp': 0.004086771438464067, 'sys_len': 2, 'ref_len': 13}
import numpy as np
import tensorflow as tf
from tqdm import tqdm
# Separate collator for generation: it pads to a multiple of 128, unlike the
# collator used for training.
generation_data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    pad_to_multiple_of=128,
    return_tensors="tf",
)

# Validation split served in order, 8 examples per batch, for generation.
tf_generate_dataset = model.prepare_tf_dataset(
    tokenized_datasets["validation"],
    batch_size=8,
    shuffle=False,
    collate_fn=generation_data_collator,
)
@tf.function(jit_compile=True)
def generate_with_xla(batch):
    """Generate output token ids for one batch inside a JIT-compiled tf.function.

    Only the "input_ids" and "attention_mask" entries of `batch` are forwarded
    to `model.generate`; generation is capped at 128 new tokens.
    """
    generation_inputs = {
        "input_ids": batch["input_ids"],
        "attention_mask": batch["attention_mask"],
    }
    return model.generate(**generation_inputs, max_new_tokens=128)
def compute_metrics():
    """Score the model on `tf_generate_dataset` and return {"bleu": score}.

    For each (features, labels) batch, translations are generated with
    `generate_with_xla` and decoded to text. The labels are decoded as well,
    after their -100 entries are replaced by the tokenizer's pad token id.
    Decoded strings are stripped, and each reference is wrapped in its own
    one-element list before everything is handed to `metric.compute`.
    """
    hypotheses = []
    reference_sets = []
    pad_id = tokenizer.pad_token_id

    for features, label_ids in tqdm(tf_generate_dataset):
        generated = generate_with_xla(features)
        hypothesis_texts = tokenizer.batch_decode(generated, skip_special_tokens=True)

        # -100 entries are swapped for the pad id before the labels are decoded.
        label_array = label_ids.numpy()
        label_array = np.where(label_array == -100, pad_id, label_array)
        reference_texts = tokenizer.batch_decode(label_array, skip_special_tokens=True)

        hypotheses.extend(text.strip() for text in hypothesis_texts)
        # Each prediction gets a list of references, here with a single entry.
        reference_sets.extend([text.strip()] for text in reference_texts)

    bleu = metric.compute(predictions=hypotheses, references=reference_sets)
    return {"bleu": bleu["score"]}
from huggingface_hub import notebook_login
notebook_login()
print(compute_metrics())
# Training setup: build the optimizer, compile the model, register the Hub
# callback and run the fit loop.
from transformers import create_optimizer
from transformers.keras_callbacks import PushToHubCallback
import tensorflow as tf
# The number of training steps is the number of samples in the dataset, divided by the batch size then multiplied
# by the total number of epochs. Note that the tf_train_dataset here is a batched tf.data.Dataset,
# not the original Hugging Face Dataset, so its len() is already num_samples // batch_size.
num_epochs = 3
num_train_steps = len(tf_train_dataset) * num_epochs
# create_optimizer returns an (optimizer, schedule) pair; only the optimizer is
# used in the code visible here. No warmup steps are requested.
optimizer, schedule = create_optimizer(
init_lr=5e-5,
num_warmup_steps=0,
num_train_steps=num_train_steps,
weight_decay_rate=0.01,
)
# No loss is passed to compile(); presumably the model falls back to its own
# internal loss - TODO confirm.
model.compile(optimizer=optimizer)
# Train in mixed-precision float16
# NOTE(review): this policy is set after the model has already been instantiated
# and compiled. Keras dtype policies are normally picked up when layers are
# created, so confirm that this call actually changes the compute dtype of the
# existing model; if not, it would need to run before the model is loaded.
tf.keras.mixed_precision.set_global_policy("mixed_float16")
# (PushToHubCallback was already imported above; the repeated import is harmless.)
from transformers.keras_callbacks import PushToHubCallback
# The callback is given the output directory and the tokenizer; presumably it
# uploads the model and tokenizer to the Hub during training - verify against
# the PushToHubCallback documentation.
callback = PushToHubCallback(
output_dir="marian-finetuned-kde4-en-to-fr", tokenizer=tokenizer
)
# Fine-tune for num_epochs epochs, evaluating on tf_eval_dataset.
model.fit(
tf_train_dataset,
validation_data=tf_eval_dataset,
callbacks=[callback],
epochs=num_epochs,
)
print(compute_metrics())
from transformers import pipeline
# Replace this with your own checkpoint
model_checkpoint = "huggingface-course/marian-finetuned-kde4-en-to-fr"
translator = pipeline("translation", model=model_checkpoint)
translator("Default to expanded threads")
[{'translation_text': 'Par défaut, développer les fils de discussion'}]
translator(
"Unable to import %1 using the OFX importer plugin. This file is not the correct format."
)
[{'translation_text': "Impossible d'importer %1 en utilisant le module externe d'importation OFX. Ce fichier n'est pas le bon format."}]