This notebook collects the code samples from the video below, which is part of the Hugging Face course.
#@title
from IPython.display import HTML

# Embedded YouTube player for the course video that this notebook accompanies.
youtube_iframe = '<iframe width="560" height="315" src="https://www.youtube.com/embed/AUozVp78dhk?rel=0&controls=0&showinfo=0" frameborder="0" allowfullscreen></iframe>'
HTML(youtube_iframe)
Install the Transformers and Datasets libraries to run this notebook.
! pip install datasets transformers[sentencepiece]
This notebook collects the code samples from the video below, which is part of the Hugging Face course.
#@title
from IPython.display import HTML

# Embedded YouTube player for the course video that this notebook accompanies.
youtube_iframe = '<iframe width="560" height="315" src="https://www.youtube.com/embed/alq1l8Lv9GA?rel=0&controls=0&showinfo=0" frameborder="0" allowfullscreen></iframe>'
HTML(youtube_iframe)
Install the Transformers and Datasets libraries to run this notebook.
! pip install datasets transformers[sentencepiece]
from datasets import load_dataset
from transformers import AutoTokenizer
import numpy as np
# Tokenizer checkpoint; the model trained below should be loaded from the same one.
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# GLUE MRPC: each example is a sentence pair ("sentence1", "sentence2") with a "label".
raw_datasets = load_dataset("glue", "mrpc")
def tokenize_dataset(dataset):
    """Tokenize a batch of MRPC sentence pairs for BERT.

    Intended for ``Dataset.map(..., batched=True)``, so ``dataset`` is a batch:
    a dict mapping column names to lists of values.

    Args:
        dataset: batch containing ``"sentence1"`` and ``"sentence2"`` lists.

    Returns:
        A plain dict of tokenizer outputs (such as ``input_ids``,
        ``attention_mask`` and ``token_type_ids``), one entry per sentence
        pair, which ``map`` adds to the dataset as new columns.
    """
    encoded = tokenizer(
        dataset["sentence1"],
        dataset["sentence2"],
        # Pad every pair to the same fixed length. The TF datasets are later
        # built with to_tf_dataset() and no padding collator, and variable-length
        # (ragged) rows cannot be stacked into rectangular batch tensors.
        padding="max_length",
        max_length=128,
        truncation=True,
    )
    # BatchEncoding wraps a dict; .data is that underlying plain dict.
    return encoded.data
tokenized_datasets = raw_datasets.map(tokenize_dataset, batched=True)

# Tokenizer outputs fed to the model as inputs; "label" is returned separately
# as the training target.
model_input_columns = ["input_ids", "attention_mask", "token_type_ids"]

train_dataset = tokenized_datasets["train"].to_tf_dataset(
    columns=model_input_columns,
    label_cols=["label"],
    shuffle=True,  # shuffle training data each epoch
    batch_size=8,
)
# Evaluation order does not change validation metrics, so shuffling this split
# only adds nondeterminism and overhead.
validation_dataset = tokenized_datasets["validation"].to_tf_dataset(
    columns=model_input_columns,
    label_cols=["label"],
    shuffle=False,
    batch_size=8,
)

# Peek at the labels of the first training batch (element 1 of the (inputs, labels) pair).
next(iter(train_dataset))[1]
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification
# The model must be loaded from the same checkpoint as the tokenizer that
# produced the inputs (bert-base-uncased). A cased model would interpret the
# token ids against a different vocabulary.
checkpoint = "bert-base-uncased"
model = TFAutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

# from_logits=True: the classification head outputs unnormalized logits.
# The sparse variant takes integer class labels rather than one-hot vectors.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# The 'adam' string shortcut uses Keras' default learning rate (1e-3). That is
# much too high for fine-tuning a pretrained transformer and commonly makes
# training unstable. Small rates (around 1e-5 to 5e-5) are the usual choice.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
    loss=loss,
)
model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=3,
)