YOLOv8结合SAHI推理图像和视频


前言

在上一篇文章中,我们深入探讨了如何通过结合YOLOv8和SAHI来增强小目标检测效果
,并计算了相关评估指标,虽然我们也展示了可视化功能,但是这些功能往往需要结合实际的ground truth(GT)数据进行对比,这在实际操作中可能会稍显不便。

为了进一步简化操作,这篇文章将直接分享可以用来推理图像和视频的代码,通过这段代码,我们能够更加方便地使用SAHI进行小目标检测,而不需要反复处理和对比GT数据。

不多说啦,以下是完整的代码示例,供大家参考使用
在这里插入图片描述


视频效果

使用SAHI增强YOLOv8推理,提升小目标检测效果(附教程)


必要环境

  1. 参考上期博客
    地址:使用YOLOv8+SAHI增强小目标检测效果并计算评估指标

一、完整代码

import os
import cv2
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import argparse
from tqdm import tqdm
import time

parser = argparse.ArgumentParser(description="Object Detection Evaluation Script")
parser.add_argument('--filepath', type=str, default='test/images', help='Path to the images folder or video file')
parser.add_argument('--output_dir', type=str, default='output', help='Directory to save the output images or video')

parser.add_argument('--model_type', type=str, default='yolov8', help='Type of the detection model')
parser.add_argument('--model_path', type=str, default='yolov8n.pt', help='Path to the model weights')
parser.add_argument('--confidence_threshold', type=float, default=0.5, help='Confidence threshold for the model')
parser.add_argument('--device', type=str, default="cuda:0", help='Device to run the model on')
parser.add_argument('--slice_height', type=int, default=256, help='Height of the image slices')
parser.add_argument('--slice_width', type=int, default=256, help='Width of the image slices')
parser.add_argument('--overlap_height_ratio', type=float, default=0.2, help='Overlap height ratio for slicing')
parser.add_argument('--overlap_width_ratio', type=float, default=0.2, help='Overlap width ratio for slicing')
parser.add_argument('--visualize_predictions', action='store_true', default=False, help='Visualize prediction results')
parser.add_argument('--images_format', type=str, nargs='+', default=['.png', '.jpg', '.jpeg'],
                    help='List of acceptable image formats')
parser.add_argument('--videos_format', type=str, nargs='+', default=['.mp4', '.avi'],
                    help='List of acceptable video formats')
args = parser.parse_args()


def get_color(idx):
    idx = int(idx) + 5
    return ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)


def load_detection_model():
    return AutoDetectionModel.from_pretrained(
        model_type=args.model_type,
        model_path=args.model_path,
        confidence_threshold=args.confidence_threshold,
        device=args.device
    )


def process_image(image_name, model):
    img_path = os.path.join(args.filepath, image_name)
    img_vis = cv2.imread(img_path)

    result = get_sliced_prediction(
        img_path,
        model,
        slice_height=args.slice_height,
        slice_width=args.slice_width,
        overlap_height_ratio=args.overlap_height_ratio,
        overlap_width_ratio=args.overlap_width_ratio,
    )

    for pred in result.object_prediction_list:
        bbox = pred.bbox
        cls = pred.category.id
        score = pred.score.value
        name = pred.category.name
        xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy

        cv2.rectangle(img_vis, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                      get_color(cls), 2, cv2.LINE_AA)
        cv2.putText(img_vis, f"{name} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                    cv2.FONT_HERSHEY_COMPLEX, 0.5, get_color(cls), thickness=2)

    if args.visualize_predictions:
        cv2.imshow(image_name, img_vis)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    # 保存结果图像
    output_path = os.path.join(args.output_dir, image_name)
    os.makedirs(args.output_dir, exist_ok=True)
    cv2.imwrite(output_path, img_vis)


def process_video(video_path, model):
    cap = cv2.VideoCapture(video_path)
    output_path = os.path.join(args.output_dir, os.path.basename(video_path))
    os.makedirs(args.output_dir, exist_ok=True)

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = None
    fps_list = []

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        start_time = time.time()

        result = get_sliced_prediction(
            frame,
            model,
            slice_height=args.slice_height,
            slice_width=args.slice_width,
            overlap_height_ratio=args.overlap_height_ratio,
            overlap_width_ratio=args.overlap_width_ratio,
        )

        for pred in result.object_prediction_list:
            bbox = pred.bbox
            cls = pred.category.id
            score = pred.score.value
            name = pred.category.name
            xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy

            cv2.rectangle(frame, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                          get_color(cls), 2, cv2.LINE_AA)
            cv2.putText(frame, f"{name} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                        cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(cls), thickness=2)

        end_time = time.time()
        fps = 1 / (end_time - start_time)
        fps_list.append(fps)
        cv2.putText(frame, f"FPS: {fps:.2f}", (30, 60), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 255, 0), 4)

        if out is None:
            frame_height, frame_width = frame.shape[:2]
            out = cv2.VideoWriter(output_path, fourcc, cap.get(cv2.CAP_PROP_FPS), (frame_width, frame_height))

        out.write(frame)

        if args.visualize_predictions:
            cv2.imshow('Video', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    cap.release()
    out.release()
    cv2.destroyAllWindows()

    avg_fps = sum(fps_list) / len(fps_list)
    print(f"Average FPS: {avg_fps:.2f}")


def main():
    detection_model = load_detection_model()

    if os.path.isfile(args.filepath) and os.path.splitext(args.filepath)[1].lower() in args.videos_format:
        process_video(args.filepath, detection_model)
    else:
        image_names = [name for name in os.listdir(args.filepath) if
                       os.path.splitext(name)[1].lower() in args.images_format]

        for i, image_name in enumerate(tqdm(image_names, desc="Processing images")):
            process_image(image_name, detection_model)


if __name__ == "__main__":
    main()

二、运行方法

1、 推理图像

  • 将文件命名为 yolov8_sahi_inference.py
  • 运行如下命令,输出结果将保存到output文件夹下
    python yolov8_sahi_inference.py --filepath test/images  --output_dir output --model_type yolov8 --model_path yolov8n.pt
    

关键参数详解:

  • –filepath: 指定要推理的图像文件夹的路径
  • –output_dir: 指定保存推理结果的路径
  • –model_type: 指定检测模型的类型 (默认为yolov8)
  • –model_path: 指定模型权重文件的路径
  • –visualize_predictions: 如果设置True,将在运行代码过程中可视化推理结果

效果:
在这里插入图片描述

2、 推理视频

  • 将文件命名为 yolov8_sahi_inference.py
  • 运行如下命令,输出结果将保存到output文件夹下
    python yolov8_sahi_inference.py --filepath inputvideo.mp4  --output_dir output --model_type yolov8 --model_path yolov8n.pt --visualize_predictions
    

关键参数详解:

  • –filepath: 指定要推理的视频路径
  • –output_dir: 指定保存推理结果的路径
  • –model_type: 指定检测模型的类型 (默认为yolov8)
  • –model_path: 指定模型权重文件的路径
  • –visualize_predictions: 如果设置True,将在运行代码过程中可视化推理结果

效果:
在这里插入图片描述


总结

本期博客就到这里啦,喜欢的小伙伴们可以点点关注,感谢!

最近经常在b站上更新一些有关目标检测的视频,大家感兴趣可以来看看 https://b23.tv/1upjbcG

学习交流群:995760755

评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小猫发财

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值