Florence-2 is a lightweight vision-language foundation model developed by Microsoft Azure AI and open-sourced under the MIT license. It aims to achieve a unified, prompt-based representation for diverse vision and vision-language tasks, including captioning, object detection, grounding, and segmentation. Despite its compact size, Florence-2 rivals much larger models such as Kosmos-2. Its unified representation approach, supported by the extensive FLD-5B dataset, enables a single model to excel at multiple vision tasks without the need for separate task-specific models. This efficiency makes Florence-2 a strong contender for real-world applications, particularly on devices with limited resources.
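The task to perform is selected entirely by a short prompt token passed to the model. For illustration, here are a few of the task prompts documented in the model card (this is not the full list):
# A few of Florence-2's documented task prompts; the same weights handle
# all of them, only the prompt changes:
task_prompts = [
    "<CAPTION>",  # brief image caption
    "<DETAILED_CAPTION>",  # longer caption
    "<OD>",  # object detection
    "<DENSE_REGION_CAPTION>",  # region-level captions
    "<OCR>",  # text recognition
]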
More details about the model can be found in the model's resources collection and the original paper.
In this tutorial we consider how to convert and run Florence-2 using OpenVINO.
This is a self-contained example that relies solely on its own code.
We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to the Installation Guide.
import platform
%pip install -q "einops" "torch>2.1" "torchvision" "matplotlib>=3.4" "timm>=0.9.8" "transformers==4.51.3" "pillow" "gradio>=4.19" --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q -U --pre "openvino>=2025.0"
if platform.system() == "Darwin":
    %pip install -q "numpy<2.0.0"
import requests
from pathlib import Path
if not Path("ov_florence2_helper.py").exists():
r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/florence2/ov_florence2_helper.py")
open("ov_florence2_helper.py", "w").write(r.text)
if not Path("gradio_helper.py").exists():
r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/florence2/gradio_helper.py")
open("gradio_helper.py", "w").write(r.text)
if not Path("notebook_utils.py").exists():
r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py")
open("notebook_utils.py", "w").write(r.text)
# Read more about telemetry collection at https://github.com/openvinotoolkit/openvino_notebooks?tab=readme-ov-file#-telemetry
from notebook_utils import collect_telemetry
collect_telemetry("florence2.ipynb")
The Florence-2 series consists of two models: Florence-2-base and Florence-2-large, with 0.23 billion and 0.77 billion parameters, respectively. Additionally, the authors provide versions of these models fine-tuned on a collection of downstream tasks. In this tutorial you can select one of the available models. By default, we will use Florence-2-base-ft.
from ov_florence2_helper import convert_florence2, get_model_selector
model_selector = get_model_selector()
model_selector
Dropdown(description='Model:', options=('microsoft/Florence-2-base-ft', 'microsoft/Florence-2-base', 'microsof…
Florence-2 is a PyTorch model. OpenVINO supports PyTorch models via conversion to OpenVINO Intermediate Representation (IR) using the model conversion API. The ov.convert_model function accepts an original PyTorch model instance and example input for tracing, and returns an ov.Model object representing the model in OpenVINO. The converted model can be saved to disk with the ov.save_model function or loaded directly on a device using core.compile_model.
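As a minimal, generic sketch of this workflow (the toy module below is just a stand-in; the Florence-2-specific conversion lives in the helper script):
import openvino as ov
import torch

# Any torch.nn.Module can illustrate the conversion API
toy_model = torch.nn.Linear(8, 4)
example_input = torch.zeros(1, 8)

ov_model = ov.convert_model(toy_model, example_input=example_input)  # trace and convert
ov.save_model(ov_model, "toy_model.xml")  # serialize IR (.xml + .bin) to disk

core = ov.Core()
compiled_model = core.compile_model(ov_model, "CPU")  # or compile directly for a device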
The ov_florence2_helper.py script contains helper functions for model conversion; please check its content if you are interested in the conversion details.
To sum up, the model consists of 4 parts: an image encoder that transforms the input image into visual embeddings, an input embedding layer that converts text tokens into the embedding space, and a transformer encoder and decoder that generate the answer from the combined embeddings. We will convert each part separately and then combine them in the inference pipeline.
model_id = model_selector.value
model_path = Path(model_id.split("/")[-1])
# Uncomment the line to see conversion code
# ??convert_florence2
convert_florence2(model_id, model_path)
✅ microsoft/Florence-2-base-ft already converted and can be found in Florence-2-base-ft
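To verify the conversion results, you can list the IR files in the output directory (each converted part is stored as an .xml/.bin pair; the exact filenames are defined by the helper script):
# List the OpenVINO IR files produced by conversion
for xml_file in sorted(model_path.rglob("*.xml")):
    print(xml_file.relative_to(model_path))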
from notebook_utils import device_widget
device = device_widget()
device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
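If you want to check which inference devices OpenVINO detects on your machine, you can query the runtime directly:
import openvino as ov

# Show the devices available for inference, e.g. ['CPU', 'GPU']
print(ov.Core().available_devices)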
The OVFlorence2Model class defined in ov_florence2_helper.py provides a convenient way to run the model. It accepts the directory with the converted model and the inference device as arguments. To run the model we will use the generate method.
from ov_florence2_helper import OVFlorence2Model
# Uncomment the line to see model class code
# ??OVFlorence2Model
model = OVFlorence2Model(model_path, device.value)
Additionally, for model usage we also need the processor class, which is distributed with the original model and can be loaded using AutoProcessor from the transformers library. The processor is responsible for preparing the input data and decoding the model output.
import requests
from PIL import Image
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
prompt = "<OD>"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
image
Let's check the model's capabilities in object detection.
inputs = processor(text=prompt, images=image, return_tensors="pt")
generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, do_sample=False, num_beams=3)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = processor.post_process_generation(generated_text, task="<OD>", image_size=(image.width, image.height))
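Before plotting, it can be useful to inspect the parsed result. For the <OD> task, the model card documents a dictionary with pixel-space bounding boxes and matching labels:
# Expected structure (per the model card):
# {'<OD>': {'bboxes': [[x1, y1, x2, y2], ...], 'labels': ['car', ...]}}
print(parsed_answer)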
from gradio_helper import plot_bbox
fig = plot_bbox(image, parsed_answer["<OD>"])
More model capabilities will be demonstrated in the interactive demo.
In this section, you can see the model in action on a variety of supported vision tasks. Please provide an input image or select one from the examples, and specify the task. Please note that some tasks additionally require a text input, e.g., a region description for segmentation or a phrase for grounding; a minimal programmatic example of such a task is shown below.
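For instance, phrase grounding expects the task token followed by the text input. Here is a minimal sketch reusing the model, processor, and image from above (the prompt format follows the model card):
# Phrase grounding: locate the image region described by free text
task = "<CAPTION_TO_PHRASE_GROUNDING>"
inputs = processor(text=task + "a green car", images=image, return_tensors="pt")
generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, do_sample=False, num_beams=3)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
print(processor.post_process_generation(generated_text, task=task, image_size=(image.width, image.height)))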
from gradio_helper import make_demo
demo = make_demo(model, processor)
try:
    demo.launch(debug=True, height=600)
except Exception:
    demo.launch(debug=True, share=True, height=600)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/