Published: 2025-09-01

# Image Segmentation in Practice: U-Net and Mask R-CNN
## 1. The U-Net Architecture and Upsampling
The U-Net architecture performs very well on image segmentation tasks. During upsampling, it uses skip connections to pull information from the corresponding layers in the left (contracting) half of the network, much like the skip connections in ResNet; this preserves the structure and object information of the original image. As it predicts a class for each pixel from convolutional features, U-Net learns to retain the original image's structure and the shapes of objects. The number of output channels typically equals the number of classes to be predicted.
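The effect of a skip connection during upsampling can be sketched with a minimal example; the tensor shapes here are illustrative only, not taken from the model built later in this section:

```python
import torch

# hypothetical decoder feature map after upsampling, and the matching
# encoder feature map carried over by the skip connection
upsampled = torch.randn(1, 512, 14, 14)
skip = torch.randn(1, 512, 14, 14)

# a skip connection concatenates along the channel dimension, so the
# decoder sees both its own features and the encoder's fine detail
merged = torch.cat([upsampled, skip], dim=1)
print(merged.shape)  # torch.Size([1, 1024, 14, 14])
```

This channel-wise concatenation is exactly why, in the model below, the `conv` blocks after each `up_conv` take `decoder_channels + encoder_channels` as their input width.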
### 1.1 The Upsampling Operation
In the U-Net architecture, upsampling is implemented with `nn.ConvTranspose2d`. This layer takes the number of input channels, the number of output channels, the kernel size, and the stride as arguments. The following example illustrates the computation:
```python
import torch
import torch.nn as nn
# initialize the transposed-convolution layer
m = nn.ConvTranspose2d(1, 1, kernel_size=(2, 2), stride=2, padding=0)
# initialize the input array
input = torch.ones(1, 1, 3, 3)
output = m(input)
print(output.shape)  # torch.Size([1, 1, 6, 6])
```
With a kernel size of 2 and a stride of 2, the 3×3 input is upsampled into a 6×6 array. During training, the optimal filter values are learned so as to reconstruct the original image as closely as possible.
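As a sanity check, the output size can also be derived from the formula given in the PyTorch documentation for `nn.ConvTranspose2d` (the helper function name here is our own):

```python
import torch
import torch.nn as nn

def convtranspose2d_out(size, kernel, stride, padding=0, output_padding=0):
    # PyTorch docs: out = (in - 1) * stride - 2 * padding + kernel + output_padding
    return (size - 1) * stride - 2 * padding + kernel + output_padding

print(convtranspose2d_out(3, kernel=2, stride=2))  # 6

# cross-check against an actual layer
m = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2)
print(m(torch.ones(1, 1, 3, 3)).shape)  # torch.Size([1, 1, 6, 6])
```

With stride equal to kernel size and no padding, a transposed convolution simply multiplies each spatial dimension by the stride, which is why it is the natural inverse of 2×2 max-pooling in the decoder.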
## 2. Implementing Semantic Segmentation with U-Net
### 2.1 Setup
First, we download the dataset, install and import the necessary packages, and define the device:
```python
import os
if not os.path.exists('dataset1'):
    !wget -q https://www.dropbox.com/s/0pigmmmynbf9xwq/dataset1.zip
    !unzip -q dataset1.zip
    !rm dataset1.zip
!pip install -q torch_snippets pytorch_model_summary
from torch_snippets import *
from torchvision import transforms
from sklearn.model_selection import train_test_split
device = 'cuda' if torch.cuda.is_available() else 'cpu'
```
### 2.2 Defining the Image Transforms
```python
tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
```
### 2.3 Defining the Dataset Class
```python
class SegData(Dataset):
    def __init__(self, split):
        self.items = stems(f'dataset1/images_prepped_{split}')
        self.split = split
    def __len__(self):
        return len(self.items)
    def __getitem__(self, ix):
        image = read(f'dataset1/images_prepped_{self.split}/{self.items[ix]}.png', 1)
        image = cv2.resize(image, (224,224))
        mask = read(f'dataset1/annotations_prepped_{self.split}/{self.items[ix]}.png')[:,:,0]
        mask = cv2.resize(mask, (224,224))
        return image, mask
    def choose(self):
        return self[randint(len(self))]
    def collate_fn(self, batch):
        ims, masks = list(zip(*batch))
        ims = torch.cat([tfms(im.copy()/255.)[None] for im in ims]).float().to(device)
        ce_masks = torch.cat([torch.Tensor(mask[None]) for mask in masks]).long().to(device)
        return ims, ce_masks
```
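What `collate_fn` produces can be sketched with synthetic arrays, with no dataset files needed; the preprocessing below mirrors the class above but skips the GPU transfer and the ImageNet normalization:

```python
import numpy as np
import torch

# a fake batch of 4 (image, mask) pairs, as __getitem__ would return them
batch = [(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8),
          np.random.randint(0, 12, (224, 224), dtype=np.uint8))
         for _ in range(4)]

ims, masks = list(zip(*batch))
# images: scale to [0, 1], move channels first, stack into a batch
ims_t = torch.stack([torch.from_numpy((im / 255.).transpose(2, 0, 1)) for im in ims]).float()
# masks: integer class indices per pixel, as CrossEntropyLoss expects
masks_t = torch.stack([torch.from_numpy(m.astype(np.int64)) for m in masks])
print(ims_t.shape, masks_t.shape)
# torch.Size([4, 3, 224, 224]) torch.Size([4, 224, 224])
```

The key detail is that masks stay 2D integer tensors (no channel axis, no one-hot encoding), because the loss defined later interprets them as class indices.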
### 2.4 Defining the Training and Validation Datasets and DataLoaders
```python
trn_ds = SegData('train')
val_ds = SegData('test')
trn_dl = DataLoader(trn_ds, batch_size=4, shuffle=True, collate_fn=trn_ds.collate_fn)
val_dl = DataLoader(val_ds, batch_size=1, shuffle=True, collate_fn=val_ds.collate_fn)
```
### 2.5 Defining the Model Architecture
```python
def conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )
def up_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        nn.ReLU(inplace=True)
    )
from torchvision.models import vgg16_bn
class UNet(nn.Module):
    def __init__(self, pretrained=True, out_channels=12):
        super().__init__()
        self.encoder = vgg16_bn(pretrained=pretrained).features
        self.block1 = nn.Sequential(*self.encoder[:6])
        self.block2 = nn.Sequential(*self.encoder[6:13])
        self.block3 = nn.Sequential(*self.encoder[13:20])
        self.block4 = nn.Sequential(*self.encoder[20:27])
        self.block5 = nn.Sequential(*self.encoder[27:34])
        self.bottleneck = nn.Sequential(*self.encoder[34:])
        self.conv_bottleneck = conv(512, 1024)
        self.up_conv6 = up_conv(1024, 512)
        self.conv6 = conv(512 + 512, 512)
        self.up_conv7 = up_conv(512, 256)
        self.conv7 = conv(256 + 512, 256)
        self.up_conv8 = up_conv(256, 128)
        self.conv8 = conv(128 + 256, 128)
        self.up_conv9 = up_conv(128, 64)
        self.conv9 = conv(64 + 128, 64)
        self.up_conv10 = up_conv(64, 32)
        self.conv10 = conv(32 + 64, 32)
        self.conv11 = nn.Conv2d(32, out_channels, kernel_size=1)
    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        bottleneck = self.bottleneck(block5)
        x = self.conv_bottleneck(bottleneck)
        x = self.up_conv6(x)
        x = torch.cat([x, block5], dim=1)
        x = self.conv6(x)
        x = self.up_conv7(x)
        x = torch.cat([x, block4], dim=1)
        x = self.conv7(x)
        x = self.up_conv8(x)
        x = torch.cat([x, block3], dim=1)
        x = self.conv8(x)
        x = self.up_conv9(x)
        x = torch.cat([x, block2], dim=1)
        x = self.conv9(x)
        x = self.up_conv10(x)
        x = torch.cat([x, block1], dim=1)
        x = self.conv10(x)
        x = self.conv11(x)
        return x
ce = nn.CrossEntropyLoss()
def UnetLoss(preds, targets):
    ce_loss = ce(preds, targets)
    acc = (torch.max(preds, 1)[1] == targets).float().mean()
    return ce_loss, acc
def train_batch(model, data, optimizer, criterion):
    model.train()
    ims, ce_masks = data
    _masks = model(ims)
    optimizer.zero_grad()
    loss, acc = criterion(_masks, ce_masks)
    loss.backward()
    optimizer.step()
    return loss.item(), acc.item()
@torch.no_grad()
def validate_batch(model, data, criterion):
    model.eval()
    ims, masks = data
    _masks = model(ims)
    loss, acc = criterion(_masks, masks)
    return loss.item(), acc.item()
```
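The tensor shapes expected by `UnetLoss` are worth spelling out: `nn.CrossEntropyLoss` takes raw logits of shape (batch, classes, H, W) and integer targets of shape (batch, H, W), so no softmax or one-hot encoding is applied beforehand. A quick check with random tensors:

```python
import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss()
preds = torch.randn(2, 12, 224, 224)           # raw logits, one channel per class
targets = torch.randint(0, 12, (2, 224, 224))  # per-pixel class indices

loss = ce(preds, targets)
# pixel accuracy: argmax over the class dimension, compared to the targets
acc = (torch.max(preds, 1)[1] == targets).float().mean()
print(loss.item() > 0, 0.0 <= acc.item() <= 1.0)  # True True
```

`torch.max(preds, 1)[1]` returns the index of the largest logit along the class dimension, i.e. the predicted class for every pixel.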