Based on this blog post, we show how, with only a few lines of custom code, we can fine-tune a Vision Transformer model for a classification task.
In addition to installing torch and skorch, you need the transformers and datasets libraries:
$ python -m pip install transformers datasets
! [ ! -z "$COLAB_GPU" ] && pip install torch skorch transformers datasets
from functools import partial
import numpy as np
import torch
from datasets import load_dataset
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import accuracy_score
from sklearn.pipeline import Pipeline
from skorch import NeuralNetClassifier
from skorch.callbacks import ProgressBar, LRScheduler
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from transformers import ViTFeatureExtractor, ViTForImageClassification
torch.manual_seed(1234)
More details on the dataset can be found on its datasets page. For our purposes, what's important is that we have image inputs and the target we're trying to predict is one of three classes for each image.
ds = load_dataset('beans')
X = ds['train']['image']
y = np.array(ds['train']['labels'])
ds['train'][0]['image']
We wrap the vision transformer feature extractor into an sklearn Transformer. It doesn't do much more than loading the feature extractor and returning the pixel values of the features. It also takes care of setting the device.
The reason to have a separate step for the feature extractor is that it only needs to be called on the images once, given that its output is deterministic. If we put it inside the nn.Module, we would call it on the same data once per epoch, which is wasteful.
class FeatureExtractor(BaseEstimator, TransformerMixin):
    def __init__(self, model_name, device='cpu'):
        self.model_name = model_name
        self.device = device

    def fit(self, X, y=None, **fit_params):
        self.extractor_ = ViTFeatureExtractor.from_pretrained(
            self.model_name, device=self.device,
        )
        return self

    def transform(self, X):
        return self.extractor_(X, return_tensors='pt')['pixel_values']
The vision transformer module itself is modified to return the logits.
class VitModule(nn.Module):
    def __init__(self, model_name, num_classes):
        super().__init__()
        self.model = ViTForImageClassification.from_pretrained(
            model_name, num_labels=num_classes
        )

    def forward(self, X):
        X = self.model(X)
        return X.logits
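Since the module returns raw logits (and nn.CrossEntropyLoss expects logits), class predictions correspond to the argmax over the logits, which is what NeuralNetClassifier's predict computes under the hood. A minimal sketch with made-up numbers:

```python
import numpy as np

# Hypothetical logits for a batch of 2 images over 3 classes.
logits = np.array([
    [2.3, -0.1, 0.4],   # highest score for class 0
    [-1.0, 0.2, 1.7],   # highest score for class 2
])

# Taking the argmax over the class dimension yields the predicted labels.
predictions = logits.argmax(axis=1)
print(predictions)  # [0 2]
```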
To stick close to the original blog post, we use the same learning rate schedule.
def lr_lambda(current_step: int, num_warmup_steps, num_training_steps):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(
        0.0,
        float(num_training_steps - current_step)
        / float(max(1, num_training_steps - num_warmup_steps)),
    )
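With no warmup steps, this schedule is a plain linear decay of the learning rate multiplier from 1 towards 0 over the training steps. A quick sanity check with the values used below (no warmup, 4 training steps, i.e. one scheduler step per epoch):

```python
def lr_lambda(current_step, num_warmup_steps, num_training_steps):
    # Same schedule as above: linear warmup, then linear decay.
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(
        0.0,
        float(num_training_steps - current_step)
        / float(max(1, num_training_steps - num_warmup_steps)),
    )

# With no warmup and 4 training steps, the multiplier decays linearly.
multipliers = [lr_lambda(step, 0, 4) for step in range(4)]
print(multipliers)  # [1.0, 0.75, 0.5, 0.25]
```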
vit_model = 'google/vit-base-patch32-224-in21k'
max_epochs = 4
batch_size = 16
optimizer = torch.optim.AdamW
learning_rate = 2e-4
weight_decay = 0.0
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr_lambda_schedule = partial(lr_lambda, num_warmup_steps=0.0, num_training_steps=max_epochs)
The model definition is straightforward. We use an sklearn Pipeline to chain the feature extractor and the model together. The model itself is a skorch NeuralNetClassifier, because we're dealing with a classification task. The module we need to pass to NeuralNetClassifier is the VitModule we defined above.
As always in skorch, to pass sub-parameters, we use the double-underscore notation. So, e.g., to pass the number of classes argument, num_classes, to the module, we set module__num_classes=3.
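Conceptually, skorch strips the module__ prefix and passes the remainder as keyword arguments when instantiating the module. A simplified, pure-Python illustration of this routing (not skorch's actual implementation):

```python
params = {
    'module__model_name': 'google/vit-base-patch32-224-in21k',
    'module__num_classes': 3,
    'lr': 2e-4,  # no prefix: this parameter goes to the net itself
}

# Keep only the module__* entries and strip the prefix, which is
# essentially how the sub-parameters reach the module's __init__.
module_kwargs = {
    key.split('__', 1)[1]: value
    for key, value in params.items()
    if key.startswith('module__')
}
print(module_kwargs)
# {'model_name': 'google/vit-base-patch32-224-in21k', 'num_classes': 3}
```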
The arguments used here are all fairly standard. Note that we use the LRScheduler callback from skorch to apply the aforementioned learning rate schedule, and we add the ProgressBar callback too, which, as the name suggests, adds a progress bar.
To stick close to the blog post, we also set train_split=False, so that skorch uses the whole training data for training. By default, skorch would instead split off a part of the training data for internal validation. But since the dataset already defines a validation split, this is not necessary.
pipe = Pipeline([
    ('feature_extractor', FeatureExtractor(
        vit_model,
        device=device,
    )),
    ('net', NeuralNetClassifier(
        VitModule,
        module__model_name=vit_model,
        module__num_classes=3,
        criterion=nn.CrossEntropyLoss,
        max_epochs=max_epochs,
        batch_size=batch_size,
        optimizer=optimizer,
        optimizer__weight_decay=weight_decay,
        lr=learning_rate,
        device=device,
        iterator_train__shuffle=True,
        train_split=False,
        callbacks=[
            LRScheduler(LambdaLR, lr_lambda=lr_lambda_schedule),
            ProgressBar(),
        ],
    )),
])
Now we're ready to train the model. As always, we just call fit with our training data. skorch will automatically show the progress bar and some training metrics like the train loss.
pipe.fit(X, y);
Some weights of the model checkpoint at google/vit-base-patch32-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch32-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  epoch    train_loss      lr      dur
-------  ------------  ------  -------
      1        0.4265  0.0002  10.5292
      2        0.1479  0.0002   8.3317
      3        0.0523  0.0001   7.9842
      4        0.0195  0.0001   8.1936
Finally, let's calculate accuracy on the predefined validation set of the dataset.
X_valid = ds['validation']['image']
y_valid = ds['validation']['labels']
y_pred = pipe.predict(X_valid)
print(f"Accuracy on validation dataset is {accuracy_score(y_valid, y_pred):.3f}")
Accuracy on validation dataset is 0.985
The accuracy is very high; we can be happy with the results 🤗
The model training process shown in this notebook can easily be turned into a training script with a rich command line interface thanks to skorch's helper functions. Take a look here.