Implementing Object Detection Models with R-CNN and Fast R-CNN
Published: 2025-09-01 01:16:52


Modern Computer Vision with PyTorch
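### Implementing Object Detection Models with R-CNN and Fast R-CNN
#### 1. Data Preparation
Before training the object detection model, prepare the training and validation datasets and their data loaders. The held-out 10% of the data is stored in `test_ds` and used for validation during training. The steps are:
- Split the data into training and test sets: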
```python
# Use the first 90% of the samples for training and the rest for validation
n_train = 9 * len(FPATHS) // 10
train_ds = RCNNDataset(FPATHS[:n_train], ROIS[:n_train],
                       CLSS[:n_train], DELTAS[:n_train],
                       GTBBS[:n_train])
test_ds = RCNNDataset(FPATHS[n_train:], ROIS[n_train:],
                      CLSS[n_train:], DELTAS[n_train:],
                      GTBBS[n_train:])
```
- Create the data loaders:
```python
from torch.utils.data import DataLoader

# Each dataset supplies its own collate_fn to assemble a batch;
# drop_last=True discards a final incomplete batch
train_loader = DataLoader(train_ds, batch_size=2,
                          collate_fn=train_ds.collate_fn,
                          drop_last=True)
test_loader = DataLoader(test_ds, batch_size=2,
                         collate_fn=test_ds.collate_fn,
                         drop_last=True)
```
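To illustrate how a dataset-provided `collate_fn` plugs into `DataLoader`, here is a minimal sketch. `ToyDataset` is a hypothetical stand-in, not the `RCNNDataset` used above (whose definition is not shown in this section); it only assumes that each item yields a variable number of crops, which is why default batching would not work and a custom collate function is needed.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical dataset: item `ix` yields (ix % 3 + 1) dummy crops."""
    def __init__(self, n_items):
        self.n_items = n_items

    def __len__(self):
        return self.n_items

    def __getitem__(self, ix):
        n_crops = ix % 3 + 1
        crops = torch.zeros(n_crops, 3, 8, 8)
        labels = torch.full((n_crops,), ix, dtype=torch.long)
        return crops, labels

    def collate_fn(self, batch):
        # Concatenate the crops of all items into one flat batch
        crops = torch.cat([c for c, _ in batch])
        labels = torch.cat([l for _, l in batch])
        return crops, labels

ds = ToyDataset(5)
loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn, drop_last=True)
batches = list(loader)
print(len(batches), [tuple(b[0].shape) for b in batches])
```

With five items and `batch_size=2`, `drop_last=True` discards the fifth item, so only two batches are produced; the number of crops per batch varies with the items it contains.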
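#### 2. R-CNN Network Architecture
With the data ready, we build the R-CNN model, which predicts both the class of each region proposal and the offsets needed to draw a tight bounding box around the object in the image. The steps are:
1. Define the VGG backbone: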
```python
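import torch
import torch.nn as nn
import torchvision.models as models

# Define `device` here in case it is not already set; falls back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Note: newer torchvision releases prefer `weights=` over `pretrained=True`
vgg_backbone = models.vgg16(pretrained=True)
# Drop the classifier head so the backbone returns flattened features
vgg_backbone.classifier = nn.Sequential()
# Freeze the backbone; only the new heads will be trained
for param in vgg_backbone.parameters():
    param.requires_grad = False
vgg_backbone.eval().to(device)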
```
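The RCNN module in the next step hard-codes `feature_dim = 25088`. As a quick sanity check, the arithmetic below sketches where that number comes from, assuming 224x224 input crops: torchvision's VGG16 applies five 2x2 max-pooling layers and ends with 512 channels, and with the classifier removed the flattened feature map is returned as is. The variable names are illustrative only.

```python
# VGG16 feature extractor: 5 max-pool layers, each halving height and width
n_max_pools = 5
out_channels = 512      # channels produced by the last conv block
input_size = 224        # crops are resized to 224x224 before the backbone

spatial = input_size // 2 ** n_max_pools          # 224 / 32
feature_dim = out_channels * spatial * spatial    # flattened feature length
print(spatial, feature_dim)
```

If the crops were resized to a different resolution, the adaptive average pooling layer in torchvision's VGG16 would still produce a 7x7 map, so the flattened size stays the same.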
2. Define the RCNN network module:
```python
import torch
import torch.nn as nn

class RCNN(nn.Module):
    def __init__(self):
        super().__init__()
        feature_dim = 25088  # flattened VGG16 output for a 224x224 crop
        self.backbone = vgg_backbone
        # Classification head: one score per class (including background)
        self.cls_score = nn.Linear(feature_dim, len(label2target))
        # Regression head: four bounding-box offsets, squashed to [-1, 1]
        self.bbox = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 4),
            nn.Tanh(),
        )
        self.cel = nn.CrossEntropyLoss()
        # Despite the name, this is plain L1 loss, not smooth L1
        self.sl1 = nn.L1Loss()
        self.lmb = 10.0  # weight of the regression term in the total loss

    def forward(self, input):
        feat = self.backbone(input)
        cls_score = self.cls_score(feat)
        bbox = self.bbox(feat)
        return cls_score, bbox

    def calc_loss(self, probs, _deltas, labels, deltas):
        detection_loss = self.cel(probs, labels)
        # Only non-background proposals (label != 0) contribute to regression
        ixs, = torch.where(labels != 0)
        _deltas = _deltas[ixs]
        deltas = deltas[ixs]
        if len(ixs) > 0:
            regression_loss = self.sl1(_deltas, deltas)
            return (detection_loss + self.lmb * regression_loss,
                    detection_loss.detach(), regression_loss.detach())
        else:
            regression_loss = 0
            return (detection_loss + self.lmb * regression_loss,
                    detection_loss.detach(), regression_loss)
```
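The loss above combines a classification term over all proposals with a regression term over non-background proposals only. The following standalone sketch reproduces that idea on toy tensors so the masking is easy to inspect. `combined_loss` is a hypothetical helper written for illustration, not part of the model class; like `calc_loss`, it assumes label 0 means background.

```python
import torch
import torch.nn.functional as F

def combined_loss(logits, pred_deltas, labels, deltas, lmb=10.0, background=0):
    """Cross-entropy on all proposals plus weighted L1 on foreground ones."""
    cls_loss = F.cross_entropy(logits, labels)
    fg = labels != background          # boolean mask of foreground proposals
    if fg.any():
        reg_loss = F.l1_loss(pred_deltas[fg], deltas[fg])
    else:
        reg_loss = torch.tensor(0.0)   # nothing to regress in this batch
    return cls_loss + lmb * reg_loss, cls_loss, reg_loss

logits = torch.zeros(4, 3)             # uniform scores over 3 classes
labels = torch.tensor([0, 1, 0, 2])    # proposals 0 and 2 are background
pred_deltas = torch.zeros(4, 4)
deltas = torch.full((4, 4), 0.1)
deltas[0] = 5.0                        # background row: must not affect the loss

total, cls_loss, reg_loss = combined_loss(logits, pred_deltas, labels, deltas)
print(total.item(), cls_loss.item(), reg_loss.item())
```

The large target placed on the background row has no effect on the regression term, which is the point of the mask: offsets are only meaningful for proposals that actually overlap an object.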
3. Define the training and validation functions:
```python
import torch
import torch.optim as optim

def train_batch(inputs, model, optimizer, criterion):
    input, clss, deltas = inputs
    model.train()
    optimizer.zero_grad()
    _clss, _deltas = model(input)
    # loc_loss holds the classification loss, regr_loss the box regression loss
    loss, loc_loss, regr_loss = criterion(_clss, _deltas, clss, deltas)
    # Predicted class = index of the highest score, as in validate_batch
    accs = clss == _clss.argmax(-1)
    loss.backward()
    optimizer.step()
    return loss.detach(), loc_loss, regr_loss, accs.cpu().numpy()

@torch.no_grad()  # gradients are not needed during validation
def validate_batch(inputs, model, criterion):
    input, clss, deltas = inputs
    model.eval()
    _clss, _deltas = model(input)
    loss, loc_loss, regr_loss = criterion(_clss, _deltas, clss, deltas)
    _, _clss = _clss.max(-1)
    accs = clss == _clss
    return _clss, _deltas, loss.detach(), loc_loss, regr_loss, accs.cpu().numpy()
```
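`validate_batch` takes the index of the highest class score as the predicted class and compares it with the labels to get a per-proposal accuracy array. A minimal sketch of that step on made-up scores (the tensor values are illustrative, not model outputs):

```python
import torch

# Toy class scores for 4 proposals over 3 classes
logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.9, 0.8, 0.1],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 1, 1, 2])

preds = logits.argmax(-1)                  # predicted class per proposal
accs = (preds == labels).cpu().numpy()     # boolean array, one entry per proposal
batch_acc = accs.mean()                    # fraction of correct predictions
print(preds.tolist(), batch_acc)
```

Returning the boolean array rather than a single number lets the caller decide how to aggregate it, for example with `accs.mean()` per batch.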
4. Create the model object and define the loss function, optimizer, and number of epochs:
```python
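# Report is the logging helper from the torch_snippets package (assumed installed)
from torch_snippets import Report

rcnn = RCNN().to(device)
criterion = rcnn.calc_loss   # the model's own combined loss
optimizer = optim.SGD(rcnn.parameters(), lr=1e-3)
n_epochs = 5
log = Report(n_epochs)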
```
5. Train the model:
```python
for epoch in range(n_epochs):
    # Training pass
    _n = len(train_loader)
    for ix, inputs in enumerate(train_loader):
        loss, loc_loss, regr_loss, accs = train_batch(inputs, rcnn,
                                                      optimizer, criterion)
        pos = (epoch + (ix+1)/_n)  # fractional epoch position for logging
        log.record(pos, trn_loss=loss.item(),
                   trn_loc_loss=loc_loss,
                   trn_regr_loss=regr_loss,
                   trn_acc=accs.mean(), end='\r')

    # Validation pass
    _n = len(test_loader)
    for ix, inputs in enumerate(test_loader):
        _clss, _deltas, loss, loc_loss, regr_loss, \
            accs = validate_batch(inputs, rcnn, criterion)
        pos = (epoch + (ix+1)/_n)
        log.record(pos, val_loss=loss.item(),
                   val_loc_loss=loc_loss,
                   val_regr_loss=regr_loss,
                   val_acc=accs.mean(), end='\r')

# Plot training and validation loss per epoch
log.plot_epochs('trn_loss,val_loss'.split(','))
```
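The training loop logs each batch at a fractional epoch position, `epoch + (ix+1)/_n`, so that the training and validation curves share one x-axis even though the two loaders have different numbers of batches. A small sketch of the values this produces, using made-up sizes (2 epochs, 4 batches per epoch):

```python
n_epochs = 2
n_batches = 4   # illustrative stand-in for len(train_loader)

positions = [epoch + (ix + 1) / n_batches
             for epoch in range(n_epochs)
             for ix in range(n_batches)]
print(positions)
```

The last batch of each epoch lands exactly on the next integer, so epoch boundaries line up across both loops.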
#### 3. Predicting on New Images
Once the model is trained, use it to make predictions on new images. The steps are:
1. Define the prediction function:
```python
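import numpy as np
import cv2
import torch

# extract_candidates and preprocess_image are helpers defined outside this section
def test_predictions(filename, show_output=True):
    # Read the image and convert BGR (OpenCV default) to RGB
    img = np.array(cv2.imread(filename, 1)[...,::-1])
    candidates = extract_candidates(img)
    # Convert (x, y, w, h) proposals to (x, y, X, Y) corner format
    candidates = [(x,y,x+w,y+h) for x,y,w,h in candidates]
    input = []
    for candidate in candidates:
        x,y,X,Y = candidate
        # Crop the proposal and resize it to the backbone's input size
        crop = cv2.resize(img[y:Y,x:X], (224,224))
        input.append(preprocess_image(crop/255.)[None])
    input = torch.cat(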
```
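The prediction function converts each region proposal from `(x, y, w, h)` to corner format before cropping. The sketch below isolates that step on a blank image to make the indexing order explicit: NumPy images are indexed as `[row, column]`, that is `[y, x]`. The helper name `xywh_to_xyxy` and the sample boxes are assumptions made for illustration.

```python
import numpy as np

def xywh_to_xyxy(boxes):
    """Convert (x, y, w, h) boxes to (x, y, X, Y) corner format."""
    return [(x, y, x + w, y + h) for x, y, w, h in boxes]

candidates = [(10, 20, 30, 40), (0, 0, 5, 5)]   # made-up proposals
corners = xywh_to_xyxy(candidates)

img = np.zeros((100, 200, 3), dtype=np.uint8)   # height 100, width 200
x, y, X, Y = corners[0]
crop = img[y:Y, x:X]                            # rows are y, columns are x
print(corners, crop.shape)
```

The crop's shape is `(height, width, channels)`, so a box that is 30 wide and 40 tall yields a 40x30 crop before it is resized.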