! pip install datasets transformers[sentencepiece]
from datasets import load_dataset, load_metric
raw_datasets = load_dataset("xsum")
raw_datasets = raw_datasets.remove_columns(["id"])
raw_datasets["train"]
print(raw_datasets["train"][1])
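# Each XSum example pairs a news article ("document") with a one-sentence "summary".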
from transformers import AutoTokenizer
model_checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
sample = raw_datasets["train"][1]
inputs = tokenizer(sample["document"])
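# as_target_tokenizer switches the tokenizer into target mode; for T5 this is a
# no-op, but models like mBART tokenize targets differently than inputs.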
with tokenizer.as_target_tokenizer():
    targets = tokenizer(sample["summary"])
print(tokenizer.convert_ids_to_tokens(inputs["input_ids"]))
print(tokenizer.convert_ids_to_tokens(targets["input_ids"]))
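# Aside (illustrative): T5 was pretrained with task prefixes, so for summarization
# it is conventional to prepend "summarize: " to each document before tokenizing.
inputs_with_prefix = tokenizer("summarize: " + sample["document"])
print(tokenizer.convert_ids_to_tokens(inputs_with_prefix["input_ids"])[:8])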
max_input_length = 1024
max_target_length = 128
def preprocess_function(examples):
    model_inputs = tokenizer(examples["document"], max_length=max_input_length, truncation=True)
    # Set up the tokenizer for the targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
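# Quick sanity check (illustrative, names are hypothetical): preprocess two raw
# examples and confirm the labels are token-id lists under the target length cap.
sample_batch = preprocess_function(raw_datasets["train"][:2])
print([len(ids) for ids in sample_batch["labels"]])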
tokenized_datasets = raw_datasets.map(
    preprocess_function, batched=True, remove_columns=["document", "summary"]
)
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq

# The collator needs the model so it can build decoder_input_ids from the labels.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
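# Illustrative check: DataCollatorForSeq2Seq pads "labels" with -100 so padded
# positions are ignored by the loss, unlike inputs, which are padded with the pad token.
features = [tokenized_datasets["train"][i] for i in range(2)]
batch = data_collator(features)
print(batch["labels"])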