Automatic mask generation using the transformers 🤗 library

This notebook demonstrates how to use the Segment Anything Model (SAM) to automatically generate segmentation masks on any image. The model was released by Meta AI in the paper Segment Anything. The original source code can be found here.

The mask-generation pipeline, freshly released for SAM, creates a grid of 1,024 points evenly spaced over the image, which are fed to the model in batches of points_per_batch points. The examples are inspired by the original notebook of the authors.
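To make the point-grid idea concrete, here is a minimal NumPy sketch (not the pipeline's internal code) of how a 32 × 32 grid of 1,024 evenly spaced prompt points could be laid out over an image in normalized coordinates:

```python
import numpy as np

def build_point_grid(n_per_side: int) -> np.ndarray:
    """Return an (n_per_side**2, 2) array of (x, y) points in (0, 1)."""
    # Offset so points sit at cell centers rather than on the image border.
    offset = 1 / (2 * n_per_side)
    coords = np.linspace(offset, 1 - offset, n_per_side)
    xs, ys = np.meshgrid(coords, coords)
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)

grid = build_point_grid(32)
print(grid.shape)  # (1024, 2)
```

Each of these 1,024 points is used as a prompt, and the resulting per-point masks are deduplicated and filtered to produce the final segmentation.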
!pip install -q git+https://github.com/huggingface/transformers.git
Run the cells below to import the needed utility functions for displaying the masks!
import numpy as np
import matplotlib.pyplot as plt
import gc
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    del mask
    gc.collect()
def show_masks_on_image(raw_image, masks):
    plt.imshow(np.array(raw_image))
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for mask in masks:
        show_mask(mask, ax=ax, random_color=True)
    plt.axis("off")
    plt.show()
    del mask
    gc.collect()
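The key trick in show_mask is NumPy broadcasting: a boolean (h, w) mask reshaped to (h, w, 1) multiplied by an RGBA color of shape (1, 1, 4) yields an (h, w, 4) overlay image that is transparent everywhere the mask is False. A tiny standalone illustration:

```python
import numpy as np

# A toy 4x4 boolean mask, standing in for one SAM output mask.
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True

# Same broadcasting as in show_mask: (h, w, 1) * (1, 1, 4) -> (h, w, 4) RGBA.
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
mask_image = mask.reshape(4, 4, 1) * color.reshape(1, 1, -1)

print(mask_image.shape)   # (4, 4, 4)
print(mask_image[0, 0])   # background pixel -> fully transparent [0. 0. 0. 0.]
```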
Instantiate the mask-generation pipeline, which loads the model from the Hub for you via the from_pretrained method. For the sake of this demonstration we will use the vit-huge checkpoint.
from transformers import pipeline
generator = pipeline("mask-generation", model="facebook/sam-vit-huge", device=0)
from PIL import Image
import requests
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
plt.imshow(raw_image)
Let's automatically generate the masks on the image! To do so, simply pass the raw image to the generator.
outputs = generator(raw_image, points_per_batch=64)
The line above takes ~7 seconds on a Google Colab 1× NVIDIA T4. Now let's look at the resulting segmentation masks.
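The points_per_batch argument controls how many of the grid points are sent through the model per forward pass, trading GPU memory for speed. With the default 32 × 32 grid and the value used above, a quick back-of-the-envelope calculation gives the number of forward passes:

```python
import math

total_points = 1024      # default 32x32 point grid
points_per_batch = 64    # value passed to the generator above

n_forward_passes = math.ceil(total_points / points_per_batch)
print(n_forward_passes)  # 16 -- fewer, larger batches are faster but need more memory
```

If you hit out-of-memory errors, lowering points_per_batch is the first knob to turn.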
masks = outputs["masks"]
show_masks_on_image(raw_image, masks)
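Each entry of outputs["masks"] is a boolean (h, w) NumPy array, one per detected segment. A common post-processing step is to sort masks by area so large background segments are drawn under small objects. Here is a sketch using small synthetic masks (so it runs without the model):

```python
import numpy as np

# Synthetic stand-ins for the boolean masks the pipeline returns.
masks = [np.zeros((8, 8), dtype=bool) for _ in range(3)]
masks[0][:2, :2] = True   # area 4
masks[1][:4, :4] = True   # area 16
masks[2][:1, :8] = True   # area 8

# Sort largest-first: a mask's area is simply the count of True pixels.
masks_by_area = sorted(masks, key=lambda m: m.sum(), reverse=True)
print([int(m.sum()) for m in masks_by_area])  # [16, 8, 4]
```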
You can feed the pipeline both URLs and raw images. Here is an example:
new_image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/depth-estimation-example.jpg"
outputs = generator([raw_image, new_image_url], points_per_batch=64)
masks = outputs[1]["masks"]
raw_image = Image.open(requests.get(new_image_url, stream=True).raw).convert("RGB")
show_masks_on_image(raw_image, masks)