pip install patched_yolo_infer
import cv2
from ultralytics import YOLO
from patched_yolo_infer import (
MakeCropsDetectThem,
CombineDetections,
visualize_results_usual_yolo_inference,
visualize_results,
)
To carry out patch-based inference with YOLO models using our library, you follow a two-step procedure. First, create an instance of the MakeCropsDetectThem class, providing all desired parameters for YOLO inference and for how the image is split into patches.
Then pass the resulting object to CombineDetections, which consolidates the predictions from every overlapping crop and intelligently suppresses duplicates.
Upon completion, you receive a result object from which you can extract everything you need from the processed frame.
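A minimal sketch of this two-step flow is shown below (the image path and parameter values are illustrative; complete, runnable examples follow later in this document):
img = cv2.imread("your_image.jpg")  # any BGR image; the path here is just a placeholder
element_crops = MakeCropsDetectThem(image=img, model_path="yolov8m.pt",
                                    shape_x=600, shape_y=600,
                                    overlap_x=50, overlap_y=50)
result = CombineDetections(element_crops, nms_threshold=0.25, match_metric='IOS')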
The output obtained from the process includes several attributes that can be leveraged for further analysis or visualization (see the access snippet after this list):
img: This attribute contains the original image on which the inference was performed. It provides context for the detected objects.
confidences: This attribute holds the confidence scores associated with each detected object. These scores indicate the model's confidence level in the accuracy of its predictions.
boxes: These bounding boxes are represented as a list of lists, where each list contains four values: [x_min, y_min, x_max, y_max]. These values correspond to the coordinates of the top-left and bottom-right corners of each bounding box.
polygons: If available, this attribute provides a list containing NumPy arrays of polygon coordinates that represent segmentation masks corresponding to the detected objects. These polygons can be utilized to accurately outline the boundaries of each object.
classes_ids: This attribute contains the class IDs assigned to each detected object. These IDs correspond to specific object classes defined during the model training phase.
classes_names: These are the human-readable names corresponding to the class IDs. They provide semantic labels for the detected objects, making the results easier to interpret.
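For reference, these values are exposed as attributes of the CombineDetections result; the same names are used in the visualization examples below (filtered_polygons is populated only when running a segmentation model):
img = result.image
confidences = result.filtered_confidences
boxes = result.filtered_boxes
polygons = result.filtered_polygons
classes_ids = result.filtered_classes_id
classes_names = result.filtered_classes_names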
MakeCropsDetectThem - Class implementing cropping and passing crops through a neural network for detection/segmentation.
Args:
image (np.ndarray): Input image in BGR format.
model_path (str): Path to the YOLO model.
imgsz (int): Input image size for YOLO inference.
conf (float): Confidence threshold for YOLO detections.
iou (float): IoU threshold for non-maximum suppression within a single crop (YOLOv8).
classes_list (List[int] or None): List of classes to filter detections. If None, all classes are considered. Defaults to None.
segment (bool): Whether to perform segmentation (YOLOv8-seg).
shape_x (int): Size of the crop along the x-axis.
shape_y (int): Size of the crop along the y-axis.
overlap_x (int): Percentage of overlap along the x-axis.
overlap_y (int): Percentage of overlap along the y-axis.
show_crops (bool): Whether to visualize the cropping.
resize_initial_size (bool): Whether to resize the results back to the original image size (note: slow operation).
model: Pre-initialized model object. If provided, it is used directly instead of loading from model_path.
memory_optimize (bool): Memory optimization option for segmentation (gives less accurate results).
CombineDetections - Class implementing combining masks/boxes from multiple crops + NMS (Non-Maximum Suppression).
Args:
element_crops (MakeCropsDetectThem): Object containing crop information.
nms_threshold (float): IoU/IoS threshold for non-maximum suppression.
match_metric (str): Matching metric, either 'IOU' or 'IOS'.
intelligent_sorter (bool): Enable sorting by area and rounded confidence. If False, sorting is done only by confidence (usual NMS). Defaults to True.
# Load the image
img_path = 'test_images/test-image-road.jpg'
img = cv2.imread(img_path)
element_crops = MakeCropsDetectThem(
image=img,
model_path="yolov8m.pt",
segment=False,
show_crops=False,
shape_x=600,
shape_y=600,
overlap_x=50,
overlap_y=50,
conf=0.5,
iou=0.7,
classes_list=[0, 1, 2, 3, 5, 7],
resize_initial_size=True,
)
result = CombineDetections(element_crops, nms_threshold=0.05, match_metric='IOS')
Visualization:
print('Basic yolo inference:')
visualize_results_usual_yolo_inference(
img,
model=YOLO("yolov8m.pt") ,
imgsz=640,
conf=0.5,
iou=0.7,
thickness=8,
font_scale=1.1,
show_boxes=True,
delta_colors=3,
show_class=False,
axis_off=False
)
print('YOLO-Patch-Based-Inference:')
visualize_results(
img=result.image,
confidences=result.filtered_confidences,
boxes=result.filtered_boxes,
classes_ids=result.filtered_classes_id,
classes_names=result.filtered_classes_names,
thickness=8,
font_scale=1.1,
show_boxes=True,
delta_colors=3,
show_class=False,
axis_off=False
)
Basic yolo inference:
YOLO-Patch-Based-Inference:
PS: You can learn more about how to work with functions for visualizing detection and segmentation results in the examples/example_extra_functions.ipynb notebook.
element_crops = MakeCropsDetectThem(
image=img,
model_path="yolov8m.pt",
segment=False,
show_crops=True,
shape_x=600,
shape_y=600,
overlap_x=50,
overlap_y=50,
conf=0.5,
iou=0.7,
classes_list=[0, 1, 2, 3, 5, 7],
resize_initial_size=False,
)
result = CombineDetections(element_crops, nms_threshold=0.05, match_metric='IOS')
print('Before nms:')
visualize_results(
img=result.image,
confidences=result.detected_conf_list_full,
boxes=result.detected_xyxy_list_full,
classes_ids=result.detected_cls_id_list_full,
classes_names=result.detected_cls_names_list_full,
thickness=8,
font_scale=1.2,
show_boxes=True,
delta_colors=3,
dpi=200,
)
print('After nms:')
visualize_results(
img=result.image,
confidences=result.filtered_confidences,
boxes=result.filtered_boxes,
classes_ids=result.filtered_classes_id,
classes_names=result.filtered_classes_names,
thickness=8,
font_scale=1.2,
show_boxes=True,
delta_colors=3,
dpi=200,
)
Number of generated images: 45
Before nms:
After nms:
element_crops = MakeCropsDetectThem(
image=img,
model_path="yolov9c.pt",
segment=False,
show_crops=False,
shape_x=600,
shape_y=500,
overlap_x=50,
overlap_y=50,
conf=0.5,
iou=0.7,
classes_list=[0, 1, 2, 3, 5, 7],
resize_initial_size=True,
)
result = CombineDetections(element_crops, nms_threshold=0.05, match_metric='IOS')
print('Basic yolo inference:')
visualize_results_usual_yolo_inference(
img,
model=YOLO("yolov9c.pt") ,
imgsz=640,
conf=0.5,
iou=0.7,
segment=False,
thickness=8,
font_scale=1.1,
show_boxes=True,
delta_colors=3,
show_class=False,
)
print('YOLO-Patch-Based-Inference:')
visualize_results(
img=result.image,
confidences=result.filtered_confidences,
boxes=result.filtered_boxes,
classes_ids=result.filtered_classes_id,
classes_names=result.filtered_classes_names,
segment=False,
thickness=8,
show_boxes=True,
delta_colors=3,
show_class=False,
)
Basic yolo inference:
YOLO-Patch-Based-Inference:
element_crops = MakeCropsDetectThem(
image=img,
model_path="yolov8m-seg.pt",
segment=True,
show_crops=False,
shape_x=600,
shape_y=600,
overlap_x=50,
overlap_y=50,
conf=0.5,
iou=0.7,
classes_list=[0, 1, 2, 3, 5, 7],
resize_initial_size=True,
)
result = CombineDetections(element_crops, nms_threshold=0.5, match_metric='IOS')
Visualization:
print('Basic yolo inference:')
visualize_results_usual_yolo_inference(
img,
model=YOLO("yolov8m-seg.pt") ,
imgsz=640,
conf=0.5,
iou=0.7,
segment=True,
thickness=8,
font_scale=1.1,
fill_mask=True,
show_boxes=False,
delta_colors=3,
show_class=False,
axis_off=False
)
print('YOLO-Patch-Based-Inference:')
visualize_results(
img=result.image,
confidences=result.filtered_confidences,
boxes=result.filtered_boxes,
polygons=result.filtered_polygons,
classes_ids=result.filtered_classes_id,
classes_names=result.filtered_classes_names,
segment=True,
thickness=8,
font_scale=1.1,
fill_mask=True,
show_boxes=False,
delta_colors=3,
show_class=False,
axis_off=False
)
Basic yolo inference:
YOLO-Patch-Based-Inference:
element_crops = MakeCropsDetectThem(
image=img,
model_path="yolov9e-seg.pt",
segment=True,
show_crops=False,
shape_x=600,
shape_y=500,
overlap_x=50,
overlap_y=50,
conf=0.5,
iou=0.7,
classes_list=[0, 1, 2, 3, 5, 7],
resize_initial_size=True,
)
result = CombineDetections(element_crops, nms_threshold=0.5, match_metric='IOS')
print('Basic yolo inference:')
visualize_results_usual_yolo_inference(
img,
model=YOLO("yolov9e-seg.pt") ,
imgsz=640,
conf=0.5,
iou=0.7,
segment=True,
thickness=8,
font_scale=1.1,
fill_mask=True,
show_boxes=False,
delta_colors=3,
show_class=False,
axis_off=False
)
print('YOLO-Patch-Based-Inference:')
visualize_results(
img=result.image,
confidences=result.filtered_confidences,
boxes=result.filtered_boxes,
polygons=result.filtered_polygons,
classes_ids=result.filtered_classes_id,
classes_names=result.filtered_classes_names,
segment=True,
thickness=8,
font_scale=1.1,
fill_mask=True,
show_boxes=False,
delta_colors=3,
show_class=False,
axis_off=False
)
Basic yolo inference:
YOLO-Patch-Based-Inference:
from ultralytics import RTDETR
# Load the image
img_path = 'test_images/road.jpg'
img = cv2.imread(img_path)
element_crops = MakeCropsDetectThem(
image=img,
model=RTDETR('rtdetr-l.pt'),
segment=False,
shape_x=450,
shape_y=300,
overlap_x=50,
overlap_y=50,
conf=0.5,
iou=0.8,
classes_list=[0, 1, 2, 3, 5, 7],
resize_initial_size=True,
)
result = CombineDetections(element_crops, nms_threshold=0.25, match_metric='IOS')
print('Basic rtdetr inference:')
visualize_results_usual_yolo_inference(
img,
model=RTDETR('rtdetr-l.pt'),
imgsz=640,
conf=0.5,
iou=0.7,
thickness=6,
show_boxes=True,
delta_colors=3,
show_class=False,
)
print('rtdetr-patch-based-inference:')
visualize_results(
img=result.image,
confidences=result.filtered_confidences,
boxes=result.filtered_boxes,
classes_ids=result.filtered_classes_id,
classes_names=result.filtered_classes_names,
thickness=6,
show_boxes=True,
delta_colors=3,
show_class=False,
)
Basic rtdetr inference:
rtdetr-patch-based-inference:
from ultralytics import FastSAM
import matplotlib.pyplot as plt
# Load the image
img_path = 'test_images/stones.jpg'
img = cv2.imread(img_path)
plt.imshow(cv2.cvtColor(img.copy(), cv2.COLOR_BGR2RGB))
plt.show()
element_crops = MakeCropsDetectThem(
image=img,
model=FastSAM('FastSAM-x.pt'),
model_path="yolov8m.pt",
segment=True,
show_crops=True,
shape_x=400,
shape_y=300,
overlap_x=50,
overlap_y=50,
conf=0.3,
iou=0.8,
resize_initial_size=True,
)
result = CombineDetections(element_crops, nms_threshold=0.40, match_metric='IOS')
print('Basic FastSAM inference:')
visualize_results_usual_yolo_inference(
img,
model=FastSAM('FastSAM-x.pt'),
imgsz=640,
conf=0.3,
iou=0.8,
segment=True,
thickness=6,
fill_mask=True,
show_boxes=False,
show_class=False,
random_object_colors=True,
)
print('FastSAM-Patch-Based-Inference:')
visualize_results(
img=result.image,
confidences=result.filtered_confidences,
boxes=result.filtered_boxes,
polygons=result.filtered_polygons,
classes_ids=result.filtered_classes_id,
classes_names=result.filtered_classes_names,
segment=True,
thickness=6,
fill_mask=True,
show_boxes=False,
show_class=False,
random_object_colors=True
)
Number of generated images: 20
Basic FastSAM inference:
FastSAM-Patch-Based-Inference:
In this approach, all operations under the hood are performed on binary masks of the recognized objects. Storing these masks consumes a lot of memory, so this method requires more RAM and slightly more processing time. However, recognition accuracy improves significantly, which is especially noticeable when there are many objects of different sizes packed densely together. We therefore recommend this approach in production when accuracy matters more than speed and your computational resources allow storing hundreds of binary masks in RAM.
The difference in this approach lies in specifying the parameter memory_optimize=False in the MakeCropsDetectThem class.
In such a case, the informative values after processing will be the following:
img: This attribute contains the original image on which the inference was performed. It provides context for the detected objects.
confidences: This attribute holds the confidence scores associated with each detected object. These scores indicate the model's confidence level in the accuracy of its predictions.
boxes: These bounding boxes are represented as a list of lists, where each list contains four values: [x_min, y_min, x_max, y_max]. These values correspond to the coordinates of the top-left and bottom-right corners of each bounding box.
masks: This attribute provides segmentation binary masks corresponding to the detected objects. These masks can be used to precisely delineate object boundaries.
classes_ids: This attribute contains the class IDs assigned to each detected object. These IDs correspond to specific object classes defined during the model training phase.
classes_names: These are the human-readable names corresponding to the class IDs. They provide semantic labels for the detected objects, making the results easier to interpret.
Here's how you can obtain them:
img=result.image
confidences=result.filtered_confidences
boxes=result.filtered_boxes
masks=result.filtered_masks
classes_ids=result.filtered_classes_id
classes_names=result.filtered_classes_names
# Load the image
img_path = 'test_images/road.jpg'
img = cv2.imread(img_path)
Example:
element_crops = MakeCropsDetectThem(
image=img,
model_path="yolov9e-seg.pt",
segment=True,
show_crops=True,
shape_x=350,
shape_y=300,
overlap_x=50,
overlap_y=50,
conf=0.5,
iou=0.7,
imgsz=416,
classes_list=[0, 1, 2, 3, 5, 7],
resize_initial_size=True,
memory_optimize=False
)
result = CombineDetections(element_crops, nms_threshold=0.5, match_metric='IOS')
print('YOLO-Patch-Based-Inference:')
visualize_results(
img=result.image,
confidences=result.filtered_confidences,
boxes=result.filtered_boxes,
masks=result.filtered_masks,
classes_ids=result.filtered_classes_id,
classes_names=result.filtered_classes_names,
segment=True,
thickness=6,
font_scale=1.1,
fill_mask=True,
show_boxes=False,
delta_colors=3,
show_class=False,
axis_off=False
)
Number of generated images: 40
YOLO-Patch-Based-Inference:
import matplotlib.pyplot as plt
i = 11 # crop(patch) number
plt.imshow(cv2.cvtColor(element_crops.crops[i].crop.copy(), cv2.COLOR_BGR2RGB))
plt.show()
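# Visualize this crop's detections on the crop image itself (boxes and masks in crop-local coordinates)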
visualize_results(
img=element_crops.crops[i].crop,
confidences=element_crops.crops[i].detected_conf,
boxes=element_crops.crops[i].detected_xyxy,
classes_ids=element_crops.crops[i].detected_cls,
masks=element_crops.crops[i].detected_masks,
segment=True,
thickness=1,
font_scale=1.2,
fill_mask=True,
show_boxes=True,
delta_colors=3,
dpi=65,
show_class=False
)
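# Visualize the same detections on the full source image: the *_real attributes hold boxes and masks mapped back to the original image coordinates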
visualize_results(
img=element_crops.crops[i].source_image,
confidences=element_crops.crops[i].detected_conf,
boxes=element_crops.crops[i].detected_xyxy_real,
classes_ids=element_crops.crops[i].detected_cls,
masks=element_crops.crops[i].detected_masks_real,
segment=True,
thickness=1,
font_scale=1.2,
fill_mask=True,
show_boxes=True,
delta_colors=3,
dpi=150,
show_class=False
)