%pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
%pip install openmim pycocotools faster-coco-eval
%pip install mmcv==2.1.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.1/index.html
!python3 -m mim install mmdet
!wget -P COCO/DIR/ http://images.cocodataset.org/annotations/annotations_trainval2017.zip
!wget -P COCO/DIR/ http://images.cocodataset.org/zips/val2017.zip
!unzip -qq COCO/DIR/annotations_trainval2017.zip -d COCO/DIR/
!unzip -qq COCO/DIR/val2017.zip -d COCO/DIR/
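Before building on the data, it can help to confirm the archives unpacked where expected; a minimal sanity check (paths mirror the wget/unzip cells above):
import os.path as osp

# Sanity check: the annotation file and the image folder should both exist after unzipping.
for path in ["COCO/DIR/annotations/instances_val2017.json", "COCO/DIR/val2017"]:
    print(path, "exists:", osp.exists(path))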
import mmdet
import mmengine
import os.path as osp
# mim ships mmdetection's config tree inside the installed package under ".mim".
config_dir = osp.dirname(mmdet.__file__)
sub_config = "configs/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py"
config_file = osp.join(config_dir, ".mim", sub_config)
cfg = mmengine.Config.fromfile(config_file)
model_file = "https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth"
print(f"{config_file=}")
print(f"{model_file=}")
!mkdir -p -m 777 model
cfg.dump(osp.join("model", osp.basename(config_file)))
!wget -P model/ {model_file}
!ls -lah model
config_file='/home/mixaill76/.local/lib/python3.10/site-packages/mmdet/.mim/configs/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py'
model_file='https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth'
--2024-06-20 16:31:25--  https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth
Length: 22757492 (22M) [application/octet-stream]
Saving to: ‘model/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth.2’
2024-06-20 16:31:28 (11.3 MB/s) - ‘model/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth.2’ saved [22757492/22757492]
total 66M
-rw-r--r-- 1 mixaill76 mixaill76  16K Jun 20 16:31 rtmdet-ins_tiny_8xb32-300e_coco.py
-rw-r--r-- 1 mixaill76 mixaill76  22M Dec 19  2022 rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth
-rw-r--r-- 1 mixaill76 mixaill76  22M Dec 19  2022 rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth.1
-rw-r--r-- 1 mixaill76 mixaill76  22M Dec 19  2022 rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth.2
from mmdet.apis import inference_detector, init_detector
from mmengine.registry import init_default_scope
from mmdet.datasets import CocoDataset
import tqdm
import os.path as osp
import os
import torch
from mmdet.evaluation import CocoMetric
from mmdet.structures.mask import encode_mask_results
import pathlib
import copy
import time
from pycocotools.coco import COCO as pycocotools_COCO
from pycocotools.cocoeval import COCOeval as pycocotools_COCOeval
from faster_coco_eval import COCO as COCO_faster, COCOeval_faster
import pandas as pd
from IPython.display import display, Markdown
init_default_scope("mmdet")
import json
with open("./COCO/DIR/annotations/instances_val2017.json") as fd:
instances_val2017 = json.load(fd)
image_id_for_eval = [image['id'] for image in instances_val2017['images']]
# image_id_for_eval = image_id_for_eval[:100] # Select first 100 images
annotations = [ann for ann in instances_val2017['annotations'] if ann['image_id'] in image_id_for_eval]
images = [image for image in instances_val2017['images'] if image['id'] in image_id_for_eval]
instances_val2017['annotations'] = annotations
instances_val2017['images'] = images
with open("./COCO/DIR/annotations/instances_val2017_first_100.json", "w") as fd:
json.dump(instances_val2017, fd)
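As a quick sanity check (a minimal sketch reusing the already-imported json module), reload the dumped file and print its counts; as noted above, with the slicing line commented out it should report the full 5000-image validation set:
with open("./COCO/DIR/annotations/instances_val2017_first_100.json") as fd:
    subset = json.load(fd)
print(len(subset["images"]), "images,", len(subset["annotations"]), "annotations")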
model = init_detector(
"./model/rtmdet-ins_tiny_8xb32-300e_coco.py",
"./model/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth",
device=("cuda" if torch.cuda.is_available() else "cpu"),
)
Loads checkpoint by local backend from path: ./model/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth
# Minimal pipeline: only image paths and GT metadata are needed for bookkeeping here;
# inference_detector applies the model's own test-time pipeline internally.
pipeline = [
dict(type="LoadImageFromFile"),
dict(type="mmdet.LoadAnnotations", with_bbox=True),
]
dataset = CocoDataset(
data_root="./COCO/DIR/",
ann_file="annotations/instances_val2017_first_100.json",
data_prefix=dict(img="val2017/"),
pipeline=pipeline,
)
len(dataset)
loading annotations into memory... Done (t=0.35s) creating index... index created!
5000
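Each dataset item is a dict produced by the pipeline above; the inference loop below relies only on its img_path and img_id keys. A quick peek at the first item:
item = dataset[0]
print(item["img_id"], item["img_path"])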
metric = CocoMetric(metric=["bbox", "segm"])
metric.dataset_meta = model.dataset_meta

# Map the model's class names to COCO category ids so results2json can label detections.
_coco_api = COCO_faster(dataset.ann_file)
metric.cat_ids = _coco_api.get_cat_ids(cat_names=metric.dataset_meta["classes"])
# Remove stale per-image result files left over from previous runs.
images_path = pathlib.Path(dataset.data_prefix["img"])
files = list(images_path.rglob("*.segm.json"))
files += list(images_path.rglob("*.bbox.json"))

for file in tqdm.tqdm(files):
    os.remove(file.as_posix())
100%|██████████| 200/200 [00:00<00:00, 46410.00it/s]
max_images = len(dataset)

for i in tqdm.tqdm(range(max_images)):
    item = dataset[i]
    result = inference_detector(model, item["img_path"])

    # Detach predictions and move them to CPU before serialization.
    for key in result.pred_instances.all_keys():
        result.pred_instances[key] = result.pred_instances[key].detach().cpu()

    dict_result = dict(
        result.pred_instances.to_dict(), **{"img_id": item["img_id"]}
    )

    # Encode binary masks as COCO RLE so they are JSON-serializable.
    if "masks" in dict_result:
        dict_result["masks"] = encode_mask_results(
            dict_result["masks"].detach().cpu().numpy()
        )

    # Write per-image .bbox.json / .segm.json files next to the source image.
    metric.results2json(
        [dict_result], outfile_prefix=osp.splitext(item["img_path"])[0]
    )
  0%|          | 0/5000 [00:00<?, ?it/s]
/home/mixaill76/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)
  return F.conv2d(input, weight, bias, self.stride,
/home/mixaill76/.local/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 5000/5000 [04:24<00:00, 18.89it/s]
include_segm = "masks" in dict_result
print(f"{include_segm=}")
include_segm=True
dataset.data_prefix["img"]
'./COCO/DIR/val2017/'
images_path = pathlib.Path(dataset.data_prefix["img"])

# Collect the per-image result files written by results2json above.
if include_segm:
    files = list(images_path.rglob("*.segm.json"))
else:
    files = list(images_path.rglob("*.bbox.json"))

result_data = []
for file in tqdm.tqdm(files):
    result_data += COCO_faster.load_json(file)
100%|██████████| 5000/5000 [00:02<00:00, 2200.00it/s]
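Each loaded entry is one detection in COCO result format; a quick illustrative count and key listing:
# Illustrative: result_data holds one dict per detection
# (image_id, category_id, score, plus bbox and/or segmentation).
print(len(result_data))
print(sorted(result_data[0].keys()))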
def load_faster_data(ann_file, result_data):
    cocoGt = COCO_faster(ann_file)
    # loadRes mutates the result dicts (adds ids, areas), hence the deepcopy.
    cocoDt = cocoGt.loadRes(copy.deepcopy(result_data))
    return cocoGt, cocoDt

def process_faster(cocoGt, cocoDt, iouType):
    cocoEval = COCOeval_faster(cocoGt, cocoDt, iouType, print_function=print)
    # Time evaluate + accumulate + summarize only; data loading is excluded.
    ts = time.time()
    cocoEval.evaluate()
    cocoEval.accumulate()
    cocoEval.summarize()
    te = time.time()
    return te - ts

def load_pycocotools_data(ann_file, result_data):
    cocoGt = pycocotools_COCO(ann_file)
    cocoDt = cocoGt.loadRes(copy.deepcopy(result_data))
    return cocoGt, cocoDt

def process_pycocotools(cocoGt, cocoDt, iouType):
    cocoEval = pycocotools_COCOeval(cocoGt, cocoDt, iouType)
    ts = time.time()
    cocoEval.evaluate()
    cocoEval.accumulate()
    cocoEval.summarize()
    te = time.time()
    return te - ts
processors = [
["faster-coco-eval", load_faster_data, process_faster],
["pycocotools", load_pycocotools_data, process_pycocotools],
]
result_table = {}
for metric in ["bbox", "segm"] if include_segm else ["bbox"]:
if result_table.get(metric) is None:
result_table[metric] = {}
for _name, _load, _process in processors:
if result_table[metric].get(_name) is None:
result_table[metric][_name] = 0
print(f"{metric=}; {_name=}")
cocoGt, cocoDt = _load(dataset.ann_file, result_data)
result_table[metric][_name] = _process(cocoGt, cocoDt, metric)
print()
print()
metric='bbox'; _name='faster-coco-eval'
Evaluate annotation type *bbox*
COCOeval_opt.evaluate() finished...
DONE (t=4.47s).
Accumulating evaluation results...
COCOeval_opt.accumulate() finished...
DONE (t=0.00s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.405
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.576
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.440
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.207
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.446
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.578
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.332
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.540
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.574
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.338
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.638
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.765
 Average Recall     (AR) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.796
 Average Recall     (AR) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.619

metric='bbox'; _name='pycocotools'
loading annotations into memory...
Done (t=0.17s)
creating index...
index created!
Loading and preparing results...
DONE (t=0.19s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=18.75s).
Accumulating evaluation results...
DONE (t=4.52s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.405
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.576
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.440
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.207
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.446
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.578
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.332
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.540
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.574
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.338
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.638
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.765

metric='segm'; _name='faster-coco-eval'
Evaluate annotation type *segm*
COCOeval_opt.evaluate() finished...
DONE (t=6.07s).
Accumulating evaluation results...
COCOeval_opt.accumulate() finished...
DONE (t=0.00s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.354
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.551
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.376
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.149
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.393
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.533
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.302
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.473
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.500
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.252
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.567
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.716
 Average Recall     (AR) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.755
 Average Recall     (AR) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.532

metric='segm'; _name='pycocotools'
loading annotations into memory...
Done (t=0.19s)
creating index...
index created!
Loading and preparing results...
DONE (t=0.14s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *segm*
DONE (t=19.70s).
Accumulating evaluation results...
DONE (t=4.19s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.354
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.551
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.376
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.149
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.393
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.533
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.302
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.473
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.500
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.252
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.567
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.716
result_table
{'bbox': {'faster-coco-eval': 4.5286314487457275, 'pycocotools': 23.2750403881073}, 'segm': {'faster-coco-eval': 6.140021085739136, 'pycocotools': 23.897135972976685}}
df = pd.DataFrame(result_table).T.round(3)
df.index.name = "Type"
df["Profit"] = (df["pycocotools"] / df["faster-coco-eval"]).round(3)
df
| Type | faster-coco-eval | pycocotools | Profit |
|:-----|-----------------:|------------:|-------:|
| bbox | 4.529 | 23.275 | 5.139 |
| segm | 6.140 | 23.897 | 3.892 |
print(df.to_markdown())
| Type   |   faster-coco-eval |   pycocotools |   Profit |
|:-------|-------------------:|--------------:|---------:|
| bbox   |              4.529 |        23.275 |    5.139 |
| segm   |              6.14  |        23.897 |    3.892 |
display(Markdown(df.to_markdown()))
| Type | faster-coco-eval | pycocotools | Profit |
|:-----|-----------------:|------------:|-------:|
| bbox | 4.529 | 23.275 | 5.139 |
| segm | 6.14 | 23.897 | 3.892 |
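The Profit column is simply the wall-time ratio pycocotools / faster-coco-eval: for bbox, 23.275 / 4.529 ≈ 5.14, i.e. roughly a five-fold speed-up on this run (absolute timings will of course vary with hardware).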
filtered_result_data = [ann for ann in result_data if ann.get("score", 0) > 0.3]
cocoGt, cocoDt = load_faster_data(dataset.ann_file, filtered_result_data)
from faster_coco_eval.extra import Curves
cur = Curves(cocoGt, cocoDt, iou_tresh=0.5, iouType="bbox", useCats=False)  # "iou_tresh" is the library's spelling
cur.plot_pre_rec()
cur.plot_f1_confidence()
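plot_pre_rec() draws the precision-recall curve at the chosen IoU threshold (0.5 here), while plot_f1_confidence() plots F1 against the score threshold, which helps justify a confidence cut-off such as the 0.3 used above.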
from faster_coco_eval.extra import PreviewResults
image_preview_count = 1
preview = PreviewResults(
cocoGt, cocoDt, iouType="bbox", iou_tresh=0.5
)
preview.display_tp_fp_fn(
data_folder=dataset.data_prefix["img"],
image_ids=list(cocoGt.imgs.keys())[10:10+image_preview_count],
display_gt=True,
)
preview.display_matrix(normalize=True)