#!/usr/bin/env python
# coding: utf-8

# # Fine-tuning a BERT model with skorch and Hugging Face
# 
# In this notebook, we follow the fine-tuning guide from the [Hugging Face documentation](https://huggingface.co/docs/transformers/training). Please check it out if you want to know more about BERT and fine-tuning. Here, we assume that you're familiar with the general ideas.
# 
# You will learn how to:
# 
# - integrate the [Hugging Face transformers](https://huggingface.co/docs/transformers/index) library with skorch
# - use skorch to fine-tune a BERT model on a text classification task
# - use skorch with the [Hugging Face accelerate](https://huggingface.co/docs/accelerate/index) library for automatic mixed precision (AMP) training
# The first part of the notebook requires Hugging Face `transformers` as an additional dependency. If you have not already installed it, you can do so like this:
# 
# `python -m pip install transformers`

# In[1]:


import subprocess

# Installation on Google Colab
try:
    import google.colab
    subprocess.run(['python', '-m', 'pip', 'install', 'skorch', 'transformers'])
except ImportError:
    pass


# ## Imports

# In[2]:


import numpy as np
import torch
from sklearn.datasets import fetch_20newsgroups
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler, ProgressBar
from skorch.hf import HuggingfacePretrainedTokenizer
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer


# ## Parameters
# 
# Change the values below if you want to try out different model architectures and hyper-parameters.

# In[3]:


# Choose a tokenizer and BERT model that work together
TOKENIZER = "distilbert-base-uncased"
PRETRAINED_MODEL = "distilbert-base-uncased"

# model hyper-parameters
OPTIMIZER = torch.optim.AdamW
LR = 5e-5
MAX_EPOCHS = 3
CRITERION = nn.CrossEntropyLoss
BATCH_SIZE = 8

# device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


# ## Data

# In[4]:


dataset = fetch_20newsgroups()


# For this notebook, we're making use of the 20 newsgroups dataset. It is a text classification dataset with 20 classes. A decent score would be to reach 89% accuracy out of sample. For more details, read the description below:

# In[5]:


print(dataset.DESCR.split('Usage')[0])


# In[6]:


dataset.target_names


# In[7]:


X = dataset.data
y = dataset.target


# In[8]:


X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)


# In[9]:


X_train[:2]


# ## Prepare the training
# 
# We want to use a learning rate schedule that decreases the learning rate linearly from its initial value towards zero over the course of training.

# In[10]:


num_training_steps = MAX_EPOCHS * (len(X_train) // BATCH_SIZE + 1)

def lr_schedule(current_step):
    # LambdaLR multiplies the initial learning rate by this factor at each step
    factor = float(num_training_steps - current_step) / float(max(1, num_training_steps))
    assert factor > 0
    return factor


# Next we wrap the BERT module itself inside a simple `nn.Module`. The only real work for us here is to load the pretrained model and to return the _logits_ from the model output. The other outputs are not needed.

# In[11]:


class BertModule(nn.Module):
    def __init__(self, name, num_labels):
        super().__init__()
        self.name = name
        self.num_labels = num_labels
        self.reset_weights()

    def reset_weights(self):
        # load the pretrained model with a classification head for `num_labels` classes
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            self.name, num_labels=self.num_labels
        )

    def forward(self, **kwargs):
        pred = self.bert(**kwargs)
        return pred.logits


# ### Tokenizer
# 
# We make use of `HuggingfacePretrainedTokenizer`, which is a wrapper that skorch provides to use the tokenizers from Hugging Face. In this instance, we use a tokenizer that was pretrained in conjunction with BERT. The tokenizer is automatically downloaded if not already present. More on Hugging Face tokenizers can be found [here](https://huggingface.co/docs/tokenizers/index).
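# To get a feel for what the tokenization step produces, the optional cell below runs the plain Hugging Face `AutoTokenizer` (which `HuggingfacePretrainedTokenizer` wraps) on a small made-up example. This is only an illustrative sanity check and is not needed for the pipeline; the exact set of keys can differ between tokenizers.

# In[ ]:


# Optional sanity check: inspect the raw tokenizer output (illustrative only)
hf_tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
encoded = hf_tokenizer(
    ["skorch makes fine-tuning BERT straightforward."],
    truncation=True,
    padding=True,
    return_tensors="pt",
)
print(encoded["input_ids"])       # token ids, shape (batch_size, seq_len)
print(encoded["attention_mask"])  # 1 for real tokens, 0 for padding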
# ## Training
# 
# ### Putting it all together
# 
# Now we can put together all the parts from above. There is nothing special going on here: we simply use an sklearn `Pipeline` to chain the `HuggingfacePretrainedTokenizer` and the neural net.
# 
# Using skorch's `NeuralNetClassifier`, we make sure to pass the `BertModule` as the first argument and to set the number of labels based on `y_train`. The criterion is `CrossEntropyLoss` because we return the logits. Moreover, we make use of the learning rate schedule we defined above, and we add the `ProgressBar` callback to monitor our progress.

# In[12]:


pipeline = Pipeline([
    ('tokenizer', HuggingfacePretrainedTokenizer(TOKENIZER)),
    ('net', NeuralNetClassifier(
        BertModule,
        module__name=PRETRAINED_MODEL,
        module__num_labels=len(set(y_train)),
        optimizer=OPTIMIZER,
        lr=LR,
        max_epochs=MAX_EPOCHS,
        criterion=CRITERION,
        batch_size=BATCH_SIZE,
        iterator_train__shuffle=True,
        device=DEVICE,
        callbacks=[
            LRScheduler(LambdaLR, lr_lambda=lr_schedule, step_every='batch'),
            ProgressBar(),
        ],
    )),
])


# Since we are using skorch, we could now take this pipeline to run a grid search or another kind of hyper-parameter sweep to figure out the best hyper-parameters for this model, e.g. trying out a different BERT model or a different `max_length` (a sketch of such a search is shown at the end of the notebook).

# ### Fitting

# In[13]:


torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)


# In[14]:


get_ipython().run_line_magic('time', 'pipeline.fit(X_train, y_train)')


# ### Evaluation

# In[15]:


get_ipython().run_cell_magic('time', '', 'with torch.inference_mode():\n    y_pred = pipeline.predict(X_test)\n')


# In[16]:


accuracy_score(y_test, y_pred)


# We can be happy with the results. We set ourselves the goal of reaching or exceeding 89% accuracy on the test set, and we managed to do that.

# ## Training with automatic mixed precision (AMP)
# 
# For this to work, you need:
# 
# - A GPU that is capable of mixed precision training
# - The [accelerate library](https://huggingface.co/docs/accelerate/index), which you can install like this: `python -m pip install 'accelerate>=0.11'`
# - skorch version 0.12 or a version installed from the current master branch (`python -m pip install git+https://github.com/skorch-dev/skorch.git`)
# 
# Again, we assume that you're familiar with the general concept of mixed precision training. For more information on how skorch integrates with accelerate, please consult the [skorch docs](https://skorch.readthedocs.io/en/latest/user/huggingface.html#accelerate).

# In[17]:


import subprocess
subprocess.run(['python', '-m', 'pip', 'install', 'accelerate>=0.11'])


# In[18]:


from accelerate import Accelerator
from skorch.hf import AccelerateMixin


# In[19]:


class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
    """NeuralNetClassifier with accelerate support"""


# In[20]:


accelerator = Accelerator(mixed_precision='fp16')


# In[21]:


pipeline2 = Pipeline([
    ('tokenizer', HuggingfacePretrainedTokenizer(TOKENIZER)),
    ('net', AcceleratedNet(          # <= changed
        BertModule,
        accelerator=accelerator,     # <= changed
        module__name=PRETRAINED_MODEL,
        module__num_labels=len(set(y_train)),
        optimizer=OPTIMIZER,
        lr=LR,
        max_epochs=MAX_EPOCHS,
        criterion=CRITERION,
        batch_size=BATCH_SIZE,
        iterator_train__shuffle=True,
        # device=DEVICE,             # <= changed: accelerate handles device placement
        callbacks=[
            LRScheduler(LambdaLR, lr_lambda=lr_schedule, step_every='batch'),
            ProgressBar(),
        ],
    )),
])


# In[22]:


torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)


# In[23]:


pipeline2.fit(X_train, y_train)


# In[24]:


get_ipython().run_cell_magic('time', '', 'with torch.inference_mode():\n    y_pred = pipeline2.predict(X_test)\n')


# In[25]:


accuracy_score(y_test, y_pred)


# Using AMP, we could reduce our training and prediction time by half, while attaining the same scores.
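# As a rough sketch of the hyper-parameter search mentioned above, the cell below sets up a `GridSearchCV` over the pipeline. It is illustrative only and not executed here, since each candidate re-fits a full BERT model; the `tokenizer__max_length` entry assumes that `HuggingfacePretrainedTokenizer` exposes a `max_length` argument.

# In[ ]:


# Sketch of a hyper-parameter sweep over the whole pipeline (not executed here,
# because every candidate re-fits a full BERT model).
from sklearn.model_selection import GridSearchCV

param_grid = {
    'net__lr': [5e-5, 3e-5],
    # assumes HuggingfacePretrainedTokenizer accepts a `max_length` argument
    'tokenizer__max_length': [128, 256],
}
search = GridSearchCV(pipeline, param_grid, scoring='accuracy', cv=2, refit=False)
# search.fit(X_train, y_train)
# print(search.best_params_, search.best_score_)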