# Ultralytics YOLO 🚀, AGPL-3.0许可证
import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
class ClassificationPredictor(BasePredictor):
    """
    A predictor that extends BasePredictor for classification models.

    Notes:
        - Torchvision classification models can also be passed via the 'model' argument,
          e.g. model='resnet18'.

    Example:
        ```python
        from ultralytics.utils import ASSETS
        from ultralytics.models.yolo.classify import ClassificationPredictor

        args = dict(model='yolov8n-cls.pt', source=ASSETS)
        predictor = ClassificationPredictor(overrides=args)
        predictor.predict_cli()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the predictor with optional config overrides and callbacks, then set the task to 'classify'."""
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = 'classify'

    def preprocess(self, img):
        """
        Convert the input batch into a tensor on the model's device with the model's float precision.

        Args:
            img: Either a torch.Tensor (used as-is) or an iterable of items; each item is passed
                through self.transforms and the results are stacked along a new leading dimension.

        Returns:
            A torch.Tensor moved to self.model.device, cast to half precision when self.model.fp16
            is truthy and to float32 otherwise.
        """
        if not isinstance(img, torch.Tensor):
            img = torch.stack([self.transforms(im) for im in img], dim=0)
        # img is guaranteed to be a tensor here (torch.stack returns one), so no numpy conversion is needed.
        img = img.to(self.model.device)
        return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32

    def postprocess(self, preds, img, orig_imgs):
        """
        Wrap each prediction in a Results object.

        Args:
            preds: Iterable of per-image predictions; each one is handed to Results as `probs`.
            img: The preprocessed input batch (not used by this method).
            orig_imgs: The original images; anything that is not a list is converted with
                ops.convert_torch2numpy_batch before indexing.

        Returns:
            A list of Results, one per prediction, each carrying the original image, the path
            taken from self.batch[0] at the same index, and the model's class names.
        """
        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
        # self.batch[0] is indexed per prediction and used as the Results path.
        return [
            Results(orig_imgs[i], path=self.batch[0][i], names=self.model.names, probs=pred)
            for i, pred in enumerate(preds)
        ]
# YOLOv8 classification task: prediction
# First published 2023-11-27 19:58:35