This notebook shows how to fine-tune a pre-trained Vision model for Video Classification on a custom dataset. The idea is to add a randomly initialized classification head on top of a pre-trained encoder and fine-tune the model altogether on a labeled dataset.
This notebook uses a subset of the UCF-101 dataset to keep the runtime of the tutorial short. The subset was prepared using this notebook following this guide.
We'll fine-tune the VideoMAE model, which was pre-trained on the Kinetics 400 dataset. You can find the other variants of VideoMAE available on 🤗 Hub here. You can also extend this notebook to use other video models such as X-CLIP.
Note that for models where there's no classification head already available, you'll have to manually attach one (randomly initialized). But this is not the case for VideoMAE, since we already have a VideoMAEForVideoClassification class.
This notebook leverages TorchVision's and PyTorchVideo's transforms for applying data preprocessing transformations including data augmentation.
Depending on the model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set those two parameters, then the rest of the notebook should run smoothly.
model_ckpt = "MCG-NJU/videomae-base" # pre-trained model from which to fine-tune
batch_size = 8 # batch size for training and evaluation
Before we start, let's install the pytorchvideo, transformers, and evaluate libraries.
!pip install pytorchvideo transformers evaluate -q
If you're opening this notebook locally, make sure your environment has the latest version of those libraries installed.
To be able to share your model with the community, there are a few more steps to follow.
First you have to store your authentication token from the Hugging Face website (sign up here if you haven't already!) then execute the following cell and input your token:
from huggingface_hub import notebook_login
notebook_login()
Then you need to install Git-LFS and configure Git to store your credentials so that your model checkpoints can be uploaded:
!apt install git-lfs
!git config --global credential.helper store
We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.
from transformers.utils import send_example_telemetry
send_example_telemetry("video_classification_notebook", framework="pytorch")
In this notebook, we will see how to fine-tune one of the 🤗 Transformers vision models on a Video Classification dataset.
Given a video, the goal is to predict an appropriate class for it, like "archery".
Here we first download the subset archive and un-archive it.
from huggingface_hub import hf_hub_download
hf_dataset_identifier = "sayakpaul/ucf101-subset"
filename = "UCF101_subset.tar.gz"
file_path = hf_hub_download(
repo_id=hf_dataset_identifier, filename=filename, repo_type="dataset"
)
!tar xf {file_path}
Now, let's investigate what is inside the archive.
dataset_root_path = "UCF101_subset"
!find {dataset_root_path} | head -5
Broadly, dataset_root_path is organized like so:
UCF101_subset/
    train/
        BandMarching/
            video_1.avi
            video_2.avi
            ...
        Archery/
            video_1.avi
            video_2.avi
            ...
        ...
    val/
        BandMarching/
            video_1.avi
            video_2.avi
            ...
        Archery/
            video_1.avi
            video_2.avi
            ...
        ...
    test/
        BandMarching/
            video_1.avi
            video_2.avi
            ...
        Archery/
            video_1.avi
            video_2.avi
            ...
        ...
Let's now count the total number of videos we have.
import pathlib
dataset_root_path = pathlib.Path(dataset_root_path)
video_count_train = len(list(dataset_root_path.glob("train/*/*.avi")))
video_count_val = len(list(dataset_root_path.glob("val/*/*.avi")))
video_count_test = len(list(dataset_root_path.glob("test/*/*.avi")))
video_total = video_count_train + video_count_val + video_count_test
print(f"Total videos: {video_total}")
all_video_file_paths = (
list(dataset_root_path.glob("train/*/*.avi"))
+ list(dataset_root_path.glob("val/*/*.avi"))
+ list(dataset_root_path.glob("test/*/*.avi"))
)
all_video_file_paths[:5]
The video paths, when sorted, appear like so:
...
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c04.avi',
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c06.avi',
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01.avi',
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g09_c02.avi',
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g09_c06.avi'
...
We notice that there are video clips belonging to the same group / scene, where the group is denoted by g in the video file paths: v_ApplyEyeMakeup_g07_c04.avi and v_ApplyEyeMakeup_g07_c06.avi, for example.
For the validation and evaluation splits, we wouldn't want video clips from the same group / scene as the training split, in order to prevent data leakage. The subset that we're using in this tutorial takes this information into account.
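As a quick sanity check (not part of the original tutorial), we can verify that no group appears in more than one split. The helper below is a minimal sketch; it assumes the v_<Class>_<group>_<clip>.avi naming convention shown above.
# Sanity-check sketch: make sure no (class, group) pair is shared across splits.
def get_groups(split):
    return {
        (path.parts[-2], path.name.split("_")[2])  # e.g. ("ApplyEyeMakeup", "g07")
        for path in dataset_root_path.glob(f"{split}/*/*.avi")
    }

train_groups, val_groups, test_groups = (get_groups(s) for s in ("train", "val", "test"))
print(train_groups & val_groups)   # expected: set()
print(train_groups & test_groups)  # expected: set()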
Next up, we derive the set of labels we have in the dataset. Let's also create two dictionaries that'll be helpful when initializing the model:
label2id: maps the class names to integers.
id2label: maps the integers to class names.
class_labels = sorted({str(path).split("/")[2] for path in all_video_file_paths})
label2id = {label: i for i, label in enumerate(class_labels)}
id2label = {i: label for label, i in label2id.items()}
print(f"Unique classes: {list(label2id.keys())}.")
We've got 10 unique classes. For each class we have 30 videos in the training set.
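If you'd like to verify those numbers yourself, here is a minimal sketch (not part of the original tutorial) that counts the training videos per class from the directory structure:
from collections import Counter
print(Counter(path.parts[-2] for path in dataset_root_path.glob("train/*/*.avi")))
# Expected: a count of 30 for each of the 10 classes.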
In the next cell, we initialize a video classification model in which the encoder is loaded with the pre-trained parameters and the classification head is randomly initialized. We also initialize the image processor associated with the model; it will come in handy when writing the preprocessing pipeline for our dataset.
from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification
image_processor = VideoMAEImageProcessor.from_pretrained(model_ckpt)
model = VideoMAEForVideoClassification.from_pretrained(
model_ckpt,
label2id=label2id,
id2label=id2label,
ignore_mismatched_sizes=True, # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)
The warning is telling us we are throwing away some weights (e.g. the weights and bias of the classifier layer) and randomly initializing some others (the weights and bias of a new classifier layer). This is expected in this case, because we are adding a new head for which we don't have pre-trained weights, so the library warns us that we should fine-tune this model before using it for inference, which is exactly what we are going to do.
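If you want to take a quick look at the freshly initialized head, you can inspect the classifier module directly (VideoMAEForVideoClassification exposes it as model.classifier); its output dimension should match our number of classes:
# The randomly initialized classification head; out_features should equal the number of classes.
print(model.classifier)
print(model.config.num_labels)  # 10 for our UCF-101 subset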
Note that starting from MCG-NJU/videomae-base-finetuned-kinetics leads to better performance on this task, since that checkpoint was obtained by fine-tuning on a similar downstream task with considerable domain overlap. You can check out this checkpoint, which was obtained by fine-tuning MCG-NJU/videomae-base-finetuned-kinetics, and it obtains much better performance.
For preprocessing the videos, we'll leverage the PyTorchVideo library. We start by importing the dependencies we need.
import pytorchvideo.data
from pytorchvideo.transforms import (
ApplyTransformToKey,
Normalize,
RandomShortSideScale,
RemoveKey,
ShortSideScale,
UniformTemporalSubsample,
)
from torchvision.transforms import (
Compose,
Lambda,
RandomCrop,
RandomHorizontalFlip,
Resize,
)
For the training dataset transformations, we use a combination of uniform temporal subsampling, pixel normalization, random cropping, and random horizontal flipping. For the validation and evaluation dataset transformations, we keep the transformation chain the same except for random cropping and horizontal flipping. To learn more about the details of these transformations check out the official documentation of PyTorch Video.
We'll use the image_processor associated with the pre-trained model to obtain the following information: the image mean and standard deviation with which to normalize the video frames, and the spatial resolution to which the frames will be resized.
import os
mean = image_processor.image_mean
std = image_processor.image_std
if "shortest_edge" in image_processor.size:
height = width = image_processor.size["shortest_edge"]
else:
height = image_processor.size["height"]
width = image_processor.size["width"]
resize_to = (height, width)
num_frames_to_sample = model.config.num_frames
sample_rate = 4
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps
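For the base VideoMAE checkpoint used here, model.config.num_frames is 16, so clip_duration works out to 16 * 4 / 30 ≈ 2.13 seconds of video per sampled clip.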
# Training dataset transformations.
train_transform = Compose(
[
ApplyTransformToKey(
key="video",
transform=Compose(
[
UniformTemporalSubsample(num_frames_to_sample),
Lambda(lambda x: x / 255.0),
Normalize(mean, std),
RandomShortSideScale(min_size=256, max_size=320),
RandomCrop(resize_to),
RandomHorizontalFlip(p=0.5),
]
),
),
]
)
# Training dataset.
train_dataset = pytorchvideo.data.Ucf101(
data_path=os.path.join(dataset_root_path, "train"),
clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
decode_audio=False,
transform=train_transform,
)
# Validation and evaluation datasets' transformations.
val_transform = Compose(
[
ApplyTransformToKey(
key="video",
transform=Compose(
[
UniformTemporalSubsample(num_frames_to_sample),
Lambda(lambda x: x / 255.0),
Normalize(mean, std),
Resize(resize_to),
]
),
),
]
)
# Validation and evaluation datasets.
val_dataset = pytorchvideo.data.Ucf101(
data_path=os.path.join(dataset_root_path, "val"),
clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
decode_audio=False,
transform=val_transform,
)
test_dataset = pytorchvideo.data.Ucf101(
data_path=os.path.join(dataset_root_path, "test"),
clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
decode_audio=False,
transform=val_transform,
)
Note: The above dataset pipelines are taken from the official PyTorchVideo example. We're using the pytorchvideo.data.Ucf101() function because it's tailored for the UCF-101 dataset. Under the hood, it returns a pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset object. The LabeledVideoDataset class is the base class for all video datasets in PyTorchVideo. So, if you want to use a custom dataset not supported off-the-shelf by PyTorchVideo, you can extend the LabeledVideoDataset class accordingly; see the sketch after this note. Refer to the data API documentation to learn more. Also, if your dataset follows a similar structure (as shown above), then pytorchvideo.data.Ucf101() should work just fine.
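As a rough illustration of that custom-dataset route (a sketch of our own, not something the tutorial does), you can construct a LabeledVideoDataset directly from a list of (video_path, label_dict) pairs. The file paths below are hypothetical placeholders, and the exact constructor arguments may vary across PyTorchVideo versions:
from pytorchvideo.data.labeled_video_dataset import LabeledVideoDataset

# Hypothetical (path, label) pairs for a custom dataset.
labeled_video_paths = [
    ("my_videos/archery_001.mp4", {"label": label2id["Archery"]}),
    ("my_videos/band_marching_001.mp4", {"label": label2id["BandMarching"]}),
]
custom_dataset = LabeledVideoDataset(
    labeled_video_paths=labeled_video_paths,
    clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
    transform=train_transform,
    decode_audio=False,
)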
# We can access the `num_videos` attribute to know the number of videos we have in the
# dataset.
train_dataset.num_videos, val_dataset.num_videos, test_dataset.num_videos
Let's now take a preprocessed video from the dataset and investigate it.
sample_video = next(iter(train_dataset))
sample_video.keys()
def investigate_video(sample_video):
"""Utility to investigate the keys present in a single video sample."""
for k in sample_video:
if k == "video":
print(k, sample_video["video"].shape)
else:
print(k, sample_video[k])
print(f"Video label: {id2label[sample_video[k]]}")
investigate_video(sample_video)
We can also visualize the preprocessed videos for easier debugging.
import imageio
import numpy as np
from IPython.display import Image
def unnormalize_img(img):
"""Un-normalizes the image pixels."""
img = (img * std) + mean
# Clip before casting to uint8 to avoid overflow artifacts.
img = (img * 255).clip(0, 255).astype("uint8")
return img
def create_gif(video_tensor, filename="sample.gif"):
"""Prepares a GIF from a video tensor.
The video tensor is expected to have the following shape:
(num_frames, num_channels, height, width).
"""
frames = []
for video_frame in video_tensor:
frame_unnormalized = unnormalize_img(video_frame.permute(1, 2, 0).numpy())
frames.append(frame_unnormalized)
kargs = {"duration": 0.25}
imageio.mimsave(filename, frames, "GIF", **kargs)
return filename
def display_gif(video_tensor, gif_name="sample.gif"):
"""Prepares and displays a GIF from a video tensor."""
video_tensor = video_tensor.permute(1, 0, 2, 3)
gif_filename = create_gif(video_tensor, gif_name)
return Image(filename=gif_filename)
video_tensor = sample_video["video"]
display_gif(video_tensor)
We'll leverage Trainer from 🤗 Transformers for training the model. To instantiate a Trainer, we will need to define the training configuration and an evaluation metric. The most important is TrainingArguments, a class that contains all the attributes to configure the training. It requires an output folder name, which will be used to save the checkpoints of the model. It also helps sync all the information in the model repository on the 🤗 Hub.
Most of the training arguments are pretty self-explanatory, but one that is quite important here is remove_unused_columns=False. When this is True (the default), any features not used by the model's call function are dropped, which is usually ideal because it makes it easier to unpack inputs into the model's call function. But in our case, we need the unused features ('video' in particular) in order to create pixel_values, which is a mandatory key our model expects in its inputs, so we set it to False.
from transformers import TrainingArguments, Trainer
model_name = model_ckpt.split("/")[-1]
new_model_name = f"{model_name}-finetuned-ucf101-subset"
num_epochs = 4
args = TrainingArguments(
new_model_name,
remove_unused_columns=False,
eval_strategy="epoch",
save_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
warmup_ratio=0.1,
logging_steps=10,
load_best_model_at_end=True,
metric_for_best_model="accuracy",
push_to_hub=True,
max_steps=(train_dataset.num_videos // batch_size) * num_epochs,
)
There's no need to define num_train_epochs when instantiating TrainingArguments. Since the dataset returned by pytorchvideo.data.Ucf101() doesn't implement the __len__() method, we had to specify max_steps instead.
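For example, with the 300 training videos in this subset (30 per class across 10 classes) and batch_size = 8, that works out to (300 // 8) * 4 = 148 optimization steps.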
Next, we need to define a function for how to compute the metrics from the predictions, which will just use the metric we'll load now. The only preprocessing we have to do is to take the argmax of our predicted logits:
import evaluate
metric = evaluate.load("accuracy")
# the compute_metrics function takes a Named Tuple as input:
# predictions, which are the logits of the model as Numpy arrays,
# and label_ids, which are the ground-truth labels as Numpy arrays.
def compute_metrics(eval_pred):
"""Computes accuracy on a batch of predictions."""
predictions = np.argmax(eval_pred.predictions, axis=1)
return metric.compute(predictions=predictions, references=eval_pred.label_ids)
A note on evaluation:
In the VideoMAE paper, the authors use the following evaluation strategy. They evaluate the model on several clips from test videos and apply different crops to those clips and report the aggregate score. However, in the interest of simplicity and brevity, we don't consider that in this tutorial.
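If you do want something closer to that protocol, one simple option (a sketch of our own, not the paper's exact recipe) is to run the model on several preprocessed clips of the same video and average the logits before taking the argmax:
import torch

def aggregate_clip_logits(model, clips):
    """Average logits over a list of preprocessed clips, each of shape
    (num_channels, num_frames, height, width), taken from the same video.
    The clips are assumed to live on the same device as the model."""
    model.eval()
    with torch.no_grad():
        all_logits = [
            model(pixel_values=clip.permute(1, 0, 2, 3).unsqueeze(0)).logits
            for clip in clips
        ]
    return torch.stack(all_logits).mean(dim=0)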
We also define a collate_fn, which will be used to batch examples together. Each batch consists of 2 keys, namely pixel_values and labels.
import torch
def collate_fn(examples):
"""The collation function to be used by `Trainer` to prepare data batches."""
# permute to (num_frames, num_channels, height, width)
pixel_values = torch.stack(
[example["video"].permute(1, 0, 2, 3) for example in examples]
)
labels = torch.tensor([example["label"] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}
Then we just need to pass all of this along with our datasets to the Trainer:
trainer = Trainer(
model,
args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=image_processor,
compute_metrics=compute_metrics,
data_collator=collate_fn,
)
You might wonder why we pass along the image_processor as a tokenizer when we already preprocessed our data. This is only to make sure the image processor configuration file (stored as JSON) will also be uploaded to the repo on the Hub.
Now we can fine-tune our model by calling the train method:
train_results = trainer.train()
We can check with the evaluate method that our Trainer did reload the best model properly (if it was not the last one):
trainer.evaluate(test_dataset)
trainer.save_model()
test_results = trainer.evaluate(test_dataset)
trainer.log_metrics("test", test_results)
trainer.save_metrics("test", test_results)
trainer.save_state()
You can now upload the result of the training to the Hub, just execute this instruction (note that the Trainer will automatically create a model card as well as Tensorboard logs - see the "Training metrics" tab - amazing isn't it?):
trainer.push_to_hub()
Now that our model is trained, let's use it to run inference on a video from test_dataset.
Let's load the trained model checkpoint and fetch a video from test_dataset.
trained_model = VideoMAEForVideoClassification.from_pretrained(new_model_name)
sample_test_video = next(iter(test_dataset))
investigate_video(sample_test_video)
We then prepare the video as a torch.Tensor and run inference.
def run_inference(model, video):
"""Utility to run inference given a model and test video.
The video is assumed to be preprocessed already.
"""
# (num_frames, num_channels, height, width)
permuted_sample_test_video = video.permute(1, 0, 2, 3)
inputs = {
"pixel_values": permuted_sample_test_video.unsqueeze(0),
"labels": torch.tensor(
[sample_test_video["label"]]
), # this can be skipped if you don't have labels available.
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = {k: v.to(device) for k, v in inputs.items()}
model = model.to(device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
return logits
logits = run_inference(trained_model, sample_test_video["video"])
We can now check if the model got the prediction right.
display_gif(sample_test_video["video"])
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
And it looks like it got it right!
You can also use this model to bring in your own videos. Check out this Space to know more. The Space will also show you how to run inference for a single video file.
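If you want to try it on a raw local video file, here's a minimal sketch (the path "my_video.mp4" is a hypothetical placeholder, and it assumes the default PyAV decoder is available) that decodes a clip with PyTorchVideo's EncodedVideo, applies the same val_transform we used for evaluation, and runs the fine-tuned model:
from pytorchvideo.data.encoded_video import EncodedVideo

video = EncodedVideo.from_path("my_video.mp4")  # hypothetical local file
clip = video.get_clip(start_sec=0, end_sec=clip_duration)  # dict with a "video" tensor (C, T, H, W)
clip = val_transform(clip)  # same preprocessing as the validation/test splits

pixel_values = clip["video"].permute(1, 0, 2, 3).unsqueeze(0)  # (1, T, C, H, W)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trained_model = trained_model.to(device)
with torch.no_grad():
    logits = trained_model(pixel_values=pixel_values.to(device)).logits
print("Predicted class:", trained_model.config.id2label[logits.argmax(-1).item()])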
Now that you've learned to train a well-performing video classification model on a custom dataset, here is some homework for you: try extending this notebook to other video models (such as X-CLIP, mentioned earlier) or to a dataset of your own.
Don't forget to share your models with the community =)