This notebook is Part 3 of the enrichment notebook series, where we use various zero-shot models to enrich the metadata of an existing dataset.
If you haven't checked out Part 1 and Part 2 yet, we highly encourage you to go through them before proceeding with this notebook.
In this notebook, we show an end-to-end example of how you can enrich the metadata of your visual dataset using the open-source zero-shot image segmentation model Segment Anything (SAM).
By the end of this notebook, you'll learn how to:
First, let's install the necessary packages:
🗒 Note - We highly recommend running this notebook in a CUDA-enabled environment to reduce the run time.
!pip install -Uq fastdup git+https://github.com/facebookresearch/segment-anything.git gdown
Now, test the installation. If there's no error message, we are ready to go.
import fastdup
fastdup.__version__
'1.57'
Download the coco-minitrain dataset, a curated mini training set consisting of 20% of the COCO 2017 training dataset. The coco-minitrain dataset consists of 25,000 images and annotations.
!gdown --fuzzy https://drive.google.com/file/d/1iSXVTlkV1_DhdYpVDqsjlT4NJFQ7OkyK/view
!unzip -qq coco_minitrain_25k.zip
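To sanity-check the download, you can count the image files after unzipping. The sketch below is a minimal example; the helper name count_images and the set of file extensions it treats as images are our own assumptions, not part of fastdup.

```python
from pathlib import Path

# Hypothetical choice: file extensions treated as images (adjust as needed)
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def count_images(root: str) -> int:
    """Recursively count files under `root` whose extension looks like an image."""
    root_path = Path(root)
    if not root_path.is_dir():
        return 0  # folder missing, e.g. the unzip step has not run yet
    return sum(
        1
        for path in root_path.rglob("*")
        if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
    )


# Compare the printed count with what you expect from the download
print(count_images("coco_minitrain_25k"))
```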
In addition to zero-shot recognition and detection, fastdup also supports zero-shot segmentation using the Segment Anything Model (SAM) from Meta AI.
SAM produces high-quality object masks from input prompts such as points or boxes, and it can also generate masks for all objects in an image.
In Part 2 of the enrichment notebook series, we utilized Grounding DINO as a zero-shot detection model and ran an inference over the images in our dataset.
We ended up with a DataFrame consisting of the filename, ram_tags, grounding_dino_bboxes, grounding_dino_scores and grounding_dino_labels columns, as follows.
import pandas as pd
# Dataframe we got from Part 2
data = {
'filename': [
'coco_minitrain_25k/images/val2017/000000382734.jpg',
'coco_minitrain_25k/images/val2017/000000508730.jpg',
'coco_minitrain_25k/images/val2017/000000202339.jpg',
],
'ram_tags': [
'bath . bathroom . doorway . drain . floor . glass door . room . screen door . shower . white',
'baby . bathroom . bathroom accessory . bin . boy . brush . chair . child . comb . diaper . hair . hairbrush . play . potty . sit . stool . tile wall . toddler . toilet bowl . toilet seat . toy',
'bus . bus station . business suit . carry . catch . city bus . pillar . man . shopping bag . sign . suit . tie . tour bus . walk',
],
'grounding_dino_bboxes': [
[(94.36, 479.79, 236.6, 589.37), (4.92, 3.73, 475.19, 637.36), (95.94, 514.92, 376.53, 638.46), (41.91, 37.47, 425.01, 637.09), (115.27, 602.26, 164.17, 635.21)],
[(3.58, 2.77, 635.13, 475.62), (30.91, 104.91, 301.75, 476.29), (68.59, 105.02, 266.22, 267.8), (359.26, 116.82, 576.6, 475.9), (374.37, 116.77, 557.19, 254.07), (466.9, 0.71, 638.7, 117.05), (266.95, 433.87, 291.04, 476.78), (466.53, 349.26, 525.87, 405.73), (350.62, 272.66, 571.98, 476.46)],
[(73.28, 256.74, 135.63, 374.42), (103.53, 105.23, 267.7, 410.18), (98.31, 33.85, 271.8, 434.72), (203.78, 63.88, 463.32, 298.29), (147.5, 106.62, 163.49, 172.9), (164.1, 52.93, 272.88, 152.68), (0.49, 0.76, 82.86, 333.41), (1.96, 2.22, 477.75, 636.07), (398.15, 281.2, 479.01, 545.03), (147.02, 106.66, 163.66, 227.86), (400.67, 98.89, 476.48, 318.45), (165.71, 52.9, 372.94, 185.69)]
],
'grounding_dino_scores': [
[0.5789, 0.3895, 0.4444, 0.3018, 0.3601],
[0.5898, 0.3738, 0.3679, 0.3641, 0.362, 0.3482, 0.3804, 0.3755, 0.3742],
[0.5325, 0.4582, 0.4429, 0.4012, 0.365, 0.3587, 0.3338, 0.3322, 0.3212, 0.3168, 0.3056, 0.2986]
],
'grounding_dino_labels': [
['bath', 'bathroom', 'floor', 'glass door', 'drain'],
['bathroom', 'toddler', 'hair', 'toddler', 'hair', 'bathroom accessory', 'hairbrush', 'diaper', 'chair'],
['man', 'bus', 'shopping bag', 'bus station', 'business suit', 'city bus', 'pillar', 'sign', 'tour bus', 'carry', 'catch', 'walk'],
]
}
# Create the DataFrame
df = pd.DataFrame(data)
df
| | filename | ram_tags | grounding_dino_bboxes | grounding_dino_scores | grounding_dino_labels |
|---|---|---|---|---|---|
| 0 | coco_minitrain_25k/images/val2017/000000382734.jpg | bath . bathroom . doorway . drain . floor . glass door . room . screen door . shower . white | [(94.36, 479.79, 236.6, 589.37), (4.92, 3.73, 475.19, 637.36), (95.94, 514.92, 376.53, 638.46), (41.91, 37.47, 425.01, 637.09), (115.27, 602.26, 164.17, 635.21)] | [0.5789, 0.3895, 0.4444, 0.3018, 0.3601] | [bath, bathroom, floor, glass door, drain] |
| 1 | coco_minitrain_25k/images/val2017/000000508730.jpg | baby . bathroom . bathroom accessory . bin . boy . brush . chair . child . comb . diaper . hair . hairbrush . play . potty . sit . stool . tile wall . toddler . toilet bowl . toilet seat . toy | [(3.58, 2.77, 635.13, 475.62), (30.91, 104.91, 301.75, 476.29), (68.59, 105.02, 266.22, 267.8), (359.26, 116.82, 576.6, 475.9), (374.37, 116.77, 557.19, 254.07), (466.9, 0.71, 638.7, 117.05), (266.95, 433.87, 291.04, 476.78), (466.53, 349.26, 525.87, 405.73), (350.62, 272.66, 571.98, 476.46)] | [0.5898, 0.3738, 0.3679, 0.3641, 0.362, 0.3482, 0.3804, 0.3755, 0.3742] | [bathroom, toddler, hair, toddler, hair, bathroom accessory, hairbrush, diaper, chair] |
| 2 | coco_minitrain_25k/images/val2017/000000202339.jpg | bus . bus station . business suit . carry . catch . city bus . pillar . man . shopping bag . sign . suit . tie . tour bus . walk | [(73.28, 256.74, 135.63, 374.42), (103.53, 105.23, 267.7, 410.18), (98.31, 33.85, 271.8, 434.72), (203.78, 63.88, 463.32, 298.29), (147.5, 106.62, 163.49, 172.9), (164.1, 52.93, 272.88, 152.68), (0.49, 0.76, 82.86, 333.41), (1.96, 2.22, 477.75, 636.07), (398.15, 281.2, 479.01, 545.03), (147.02, 106.66, 163.66, 227.86), (400.67, 98.89, 476.48, 318.45), (165.71, 52.9, 372.94, 185.69)] | [0.5325, 0.4582, 0.4429, 0.4012, 0.365, 0.3587, 0.3338, 0.3322, 0.3212, 0.3168, 0.3056, 0.2986] | [man, bus, shopping bag, bus station, business suit, city bus, pillar, sign, tour bus, carry, catch, walk] |
If you'd like to reproduce the above DataFrame, the Part 2 notebook details the code you need to run.
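In each row, grounding_dino_bboxes, grounding_dino_scores and grounding_dino_labels are parallel lists: the i-th entry of each describes the same detection. Since these boxes are used as SAM prompts in the next step, you may want to drop low-confidence boxes first. The sketch below is a hypothetical pre-processing step, not something fastdup does for you; the function name filter_detections and the 0.35 default threshold are illustrative assumptions.

```python
import pandas as pd


def filter_detections(df: pd.DataFrame, min_score: float = 0.35) -> pd.DataFrame:
    """Return a copy of `df` keeping only detections with score >= `min_score`.

    Hypothetical helper: assumes the bbox, score and label columns hold
    parallel lists of equal length in every row.
    """
    out = df.copy()
    kept_boxes, kept_scores, kept_labels = [], [], []
    for boxes, scores, labels in zip(
        out["grounding_dino_bboxes"],
        out["grounding_dino_scores"],
        out["grounding_dino_labels"],
    ):
        # Indices of the detections that pass the threshold
        keep = [i for i, score in enumerate(scores) if score >= min_score]
        kept_boxes.append([boxes[i] for i in keep])
        kept_scores.append([scores[i] for i in keep])
        kept_labels.append([labels[i] for i in keep])
    # Wrap in a Series so pandas stores one list per row (object dtype)
    out["grounding_dino_bboxes"] = pd.Series(kept_boxes, index=out.index)
    out["grounding_dino_scores"] = pd.Series(kept_scores, index=out.index)
    out["grounding_dino_labels"] = pd.Series(kept_labels, index=out.index)
    return out
```

For example, filter_detections(df, min_score=0.4) would keep only the boxes Grounding DINO scored at 0.4 or higher; pick the threshold by inspecting your own results.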
As in the previous examples, you can use the enrich method to add masks to your DataFrame of images.
In the following code snippet, we load SAM and specify input_col='grounding_dino_bboxes' so that SAM uses the bounding boxes as input prompts.
fd = fastdup.create(input_dir='./coco_minitrain_25k')
df = fd.enrich(task='zero-shot-segmentation',
model='segment-anything',
input_df=df,
input_col='grounding_dino_bboxes'
)
Warning: fastdup create() without work_dir argument, output is stored in a folder named work_dir in your current working path.
INFO:fastdup.model.sam:Loading model checkpoint from - /home/dnth/sam_vit_h_4b8939.pth
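SAM box prompts are given in (xmin, ymin, xmax, ymax) pixel coordinates, and the Grounding DINO boxes above appear to follow the same layout. If your boxes come from another source, a light sanity check can catch swapped or out-of-range coordinates before segmentation. clip_boxes_xyxy below is our own NumPy sketch under that XYXY assumption, not part of fastdup.

```python
import numpy as np


def clip_boxes_xyxy(boxes, width: int, height: int) -> np.ndarray:
    """Return boxes as a float array of shape (N, 4) in XYXY order, clipped to the image.

    Hypothetical helper: coordinates are reordered so that xmin <= xmax and
    ymin <= ymax, then clipped to [0, width] and [0, height].
    """
    arr = np.asarray(boxes, dtype=float).reshape(-1, 4)
    # Reorder corners in case a box was stored with swapped coordinates
    x0 = np.minimum(arr[:, 0], arr[:, 2])
    x1 = np.maximum(arr[:, 0], arr[:, 2])
    y0 = np.minimum(arr[:, 1], arr[:, 3])
    y1 = np.maximum(arr[:, 1], arr[:, 3])
    out = np.stack([x0, y0, x1, y1], axis=1)
    # Keep every corner inside the image
    out[:, [0, 2]] = np.clip(out[:, [0, 2]], 0, width)
    out[:, [1, 3]] = np.clip(out[:, [1, 3]], 0, height)
    return out
```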
Next, for the purpose of visualization, drop the rows in the DataFrame that have no masks.
df.dropna(subset=['sam_masks'], inplace=True)
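dropna(subset=['sam_masks']) removes only the rows whose sam_masks cell is missing (None or NaN); missing values in other columns are ignored. The toy frame below uses made-up values to illustrate this; that fastdup leaves the cell empty when no mask is produced is inferred from this step rather than documented here.

```python
import pandas as pd

# Toy frame with made-up values: the second row has no mask
toy = pd.DataFrame({
    "filename": ["a.jpg", "b.jpg", "c.jpg"],
    "sam_masks": [["mask_a"], None, ["mask_c"]],
    "ram_tags": [None, "dog", "cat"],  # a missing value here does not cause a drop
})

# Only the 'sam_masks' column is inspected when deciding which rows to drop
kept = toy.dropna(subset=["sam_masks"])
print(kept["filename"].tolist())  # ['a.jpg', 'c.jpg']
```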
Plot the images with bounding boxes and masks.
from fastdup.models_utils import plot_annotations
plot_annotations(df,
image_col='filename',
tags_col='ram_tags',
bbox_col='grounding_dino_bboxes',
scores_col='grounding_dino_scores',
labels_col='grounding_dino_labels',
masks_col='sam_masks'
)
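If you would rather store one array per image than a list of masks, for example to save it alongside the other metadata, you can merge the boolean masks into a single integer label map. masks_to_label_map below is a hypothetical NumPy sketch: label 0 is background, mask i gets label i + 1, and where masks overlap the later mask wins. It assumes each mask has already been reduced to a 2D boolean array of the image size, as done with squeeze later in this notebook.

```python
import numpy as np


def masks_to_label_map(masks) -> np.ndarray:
    """Merge a list of HxW boolean masks into one HxW integer label map.

    0 means background; mask i is written as label i + 1. Where masks
    overlap, the mask that appears later in the list overwrites earlier ones.
    """
    masks = [np.asarray(m, dtype=bool) for m in masks]
    if not masks:
        raise ValueError("need at least one mask to infer the image size")
    label_map = np.zeros(masks[0].shape, dtype=np.int32)
    for i, mask in enumerate(masks):
        label_map[mask] = i + 1
    return label_map
```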
To run inference with SAM directly, import the SegmentAnythingModel class and provide an image and a bounding box as input.
from IPython.display import Image
Image("coco_minitrain_25k/images/val2017/000000449996.jpg")
from fastdup.models_sam import SegmentAnythingModel
import torch
model = SegmentAnythingModel()
result = model.run_inference(image_path="coco_minitrain_25k/images/val2017/000000449996.jpg",
bboxes=torch.tensor((1.47, 1.45, 638.46, 241.37))) # bounding box of the sky
INFO:fastdup.model.sam:Loading model checkpoint from - /home/dnth/sam_vit_h_4b8939.pth
The result is a boolean mask of the object inside the given bounding box; its last two dimensions match the image height and width (428 × 640 pixels here).
result.shape
torch.Size([1, 1, 428, 640])
result
tensor([[[[False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [ True, True, True, ..., True, True, True], ..., [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False]]]], device='cuda:0')
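Because the mask is a boolean array with the image's height and width, simple metadata can be derived from it directly: the pixel area, the tight bounding box around the True pixels, and the ratio of the mask area to the prompt box area. mask_stats below is our own NumPy sketch with a hypothetical name; to use it with the tensor above, first move it to the CPU and squeeze it to 2D, e.g. result.cpu().squeeze(0).squeeze(0).numpy().

```python
import numpy as np


def mask_stats(mask, box_xyxy):
    """Compute simple metadata for a 2D boolean mask and its prompt box.

    Returns the mask area in pixels, the tight XYXY box around the True
    pixels (None for an empty mask), and the mask area divided by the
    prompt box area (can exceed 1 if the mask spills outside the box).
    """
    mask = np.asarray(mask, dtype=bool)
    area = int(mask.sum())
    ys, xs = np.nonzero(mask)
    tight_box = None
    if area > 0:
        # +1 so the box covers the full last pixel (pixel-edge coordinates)
        tight_box = (int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1)
    x0, y0, x1, y1 = box_xyxy
    box_area = max(x1 - x0, 0) * max(y1 - y0, 0)
    fill_ratio = area / box_area if box_area > 0 else 0.0
    return {"area": area, "tight_box": tight_box, "fill_ratio": fill_ratio}
```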
Let's plot an overlay of the mask on the image.
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np
from PIL import Image
import torch
# Image
image_path = "coco_minitrain_25k/images/val2017/000000449996.jpg"
pil_image = Image.open(image_path)
image_np = np.array(pil_image)
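# Bbox: the same sky bounding box passed to run_inference above (xmin, ymin, xmax, ymax)
bbox = torch.tensor((1.47, 1.45, 638.46, 241.37)) # Replace with your actual bounding box tensor
xmin, ymin, xmax, ymax = bbox.tolist() # plain Python floats for matplotlib
# Mask
# Move the mask to the CPU, squeeze out the first two dimensions to make it 2D, and convert to NumPy
mask_2d = result.cpu().squeeze(0).squeeze(0).numpy()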
plt.imshow(image_np)
plt.imshow(mask_2d, cmap='jet', alpha=0.5)
plt.gca().add_patch(Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, linewidth=2.5, edgecolor='limegreen', facecolor='none'))
plt.axis('off')
plt.show()
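The overlay above lets imshow draw the whole mask with a colormap at alpha=0.5. If you need the blended image as an array, for example to save it without matplotlib, you can alpha-blend a solid color into the masked pixels yourself. overlay_mask below is a hypothetical NumPy sketch; the default color and alpha are arbitrary, and unlike the imshow overlay it leaves pixels outside the mask untouched.

```python
import numpy as np


def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5) -> np.ndarray:
    """Alpha-blend `color` into an HxWx3 uint8 image wherever `mask` is True.

    Pixels outside the mask are left unchanged.
    """
    image = np.asarray(image)
    mask = np.asarray(mask, dtype=bool)
    if image.shape[:2] != mask.shape:
        raise ValueError("mask must match the image height and width")
    out = image.astype(np.float32)  # work in float to avoid uint8 overflow
    color_arr = np.asarray(color, dtype=np.float32)
    out[mask] = (1.0 - alpha) * out[mask] + alpha * color_arr
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)
```

For example, overlay_mask(np.array(pil_image.convert("RGB")), mask_2d) would return the blended sky image; converting to RGB first also covers grayscale JPEGs, which have no channel dimension.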
You can also load other variants of SAM from the official SAM repo or even your own custom model.
To do so, download the sam_vit_b and sam_vit_l weights from the official SAM repo into your local folder and pass them to the constructor as follows.
model = SegmentAnythingModel(model_weights="sam_vit_b_01ec64.pth", model_type="vit_b")
model = SegmentAnythingModel(model_weights="sam_vit_l_0b3195.pth", model_type="vit_l")
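If you switch between variants often, a small lookup table keeps the checkpoint file and model_type consistent. The filenames below are the official SAM checkpoint names (the vit_h one matches the checkpoint in the log above); sam_constructor_kwargs is a hypothetical helper that only checks the file exists locally before you pass its result to SegmentAnythingModel.

```python
from pathlib import Path

# Official SAM checkpoint filenames for each backbone
SAM_CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}


def sam_constructor_kwargs(model_type: str, weights_dir: str = ".") -> dict:
    """Build the model_weights/model_type keyword arguments used in the calls above.

    Hypothetical helper: raises if the variant is unknown or the checkpoint
    file has not been downloaded into `weights_dir` yet.
    """
    if model_type not in SAM_CHECKPOINTS:
        raise ValueError(f"unknown SAM variant {model_type!r}; expected one of {sorted(SAM_CHECKPOINTS)}")
    weights = Path(weights_dir) / SAM_CHECKPOINTS[model_type]
    if not weights.is_file():
        raise FileNotFoundError(f"checkpoint not found: {weights}")
    return {"model_weights": str(weights), "model_type": model_type}
```

Usage would then look like model = SegmentAnythingModel(**sam_constructor_kwargs("vit_b")), which is equivalent to the first constructor call above once the checkpoint is in the current folder.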
In this tutorial, we showed how you can run the Segment Anything Model (SAM) as a zero-shot segmentation model to enrich your dataset.
This notebook is Part 3 of the dataset enrichment notebook series where we utilize various zero-shot models to enrich datasets.
Questions about this tutorial? Reach out to us on our Slack channel!
Next, feel free to check out other tutorials -
If you prefer a no-code platform to inspect and visualize your dataset, try our free cloud product, VL Profiler. VL Profiler is our first no-code commercial product that lets you visualize and inspect your dataset in your browser.
VL Profiler is free to get started. Upload up to 1,000,000 images for analysis at zero cost!
Sign up now.
As usual, feedback is welcome! Questions? Drop by our Slack channel or open an issue on GitHub.