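0. Previous posts
[2] Deep Learning with PyTorch: Tensor operations (concatenation, splitting, indexing, and transformation)
[7] Deep Learning with PyTorch: DataLoader and Dataset (with a hands-on RMB banknote binary classification example)
[8] Deep Learning with PyTorch: Image preprocessing with transforms
[9] Deep Learning with PyTorch: Image augmentation with transforms (cropping, flipping, rotation)
[10] Deep Learning with PyTorch: Image operations and custom methods with transforms
[11] Deep Learning with PyTorch: Model creation and nn.Module
[12] Deep Learning with PyTorch: Model containers and building AlexNet
[13] Deep Learning with PyTorch: Convolution layers (1D/2D/3D convolution, nn.Conv2d, transposed convolution nn.ConvTranspose)
[16] Deep Learning with PyTorch: 18 loss functions
[18] Deep Learning with PyTorch: Learning rate adjustment strategies
[19] Deep Learning with PyTorch: The visualization tool TensorBoard
[21] Deep Learning with PyTorch: Regularization with weight decay
[22] Deep Learning with PyTorch: Regularization with dropout
[23] Deep Learning with PyTorch: Batch Normalization
[24] Deep Learning with PyTorch: BN, LN (Layer Normalization), IN (Instance Normalization), GN (Group Normalization)
Deep Learning with PyTorch: Image Segmentation with Unet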
1. Definition of image segmentation
2. How does a model segment an image?
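As a rough illustration of the idea behind the code example in this section: the network produces one score map per class, and each pixel is assigned the class with the highest score, after which class indices are mapped to colors. The sketch below mirrors the `output.argmax(0)` step and the palette arithmetic from that example, but uses NumPy and a toy 3-class, 2x2 score array instead of a real model. The names `predict_classes`, `make_palette`, and `logits` are hypothetical and exist only for this illustration.

```python
import numpy as np


def predict_classes(logits):
    """Return the per-pixel class index from a (C, H, W) score array."""
    # argmax over the class axis: each pixel gets its highest-scoring class
    return logits.argmax(axis=0)


def make_palette(num_classes=21):
    """Build one RGB color per class, using the same arithmetic as the demo."""
    base = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1], dtype=np.int64)
    colors = np.arange(num_classes, dtype=np.int64)[:, None] * base
    return (colors % 255).astype("uint8")


# Toy scores (assumed values): 3 classes over a 2x2 image
logits = np.array([
    [[2.0, 0.1], [0.3, 0.2]],  # scores for class 0
    [[0.5, 1.5], [0.1, 0.9]],  # scores for class 1
    [[0.1, 0.2], [3.0, 0.4]],  # scores for class 2
])

class_map = predict_classes(logits)  # shape (H, W), one class index per pixel
palette = make_palette(num_classes=3)
rgb = palette[class_map]             # shape (H, W, 3), one color per pixel
print(class_map)
print(rgb.shape)
```

With a real model the score array has 21 channels (the `21*433*649` output noted in the example), but the argmax and color lookup work the same way.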
Code example:
# -*- coding: utf-8 -*-
"""
# @file name  : seg_demo.py
# @brief      : image segmentation with deeplab-V3 loaded through torch.hub
"""
import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":

    # path_img = os.path.join(BASE_DIR, "demo_img1.png")
    # path_img = os.path.join(BASE_DIR, "demo_img2.png")
    path_img = os.path.join(BASE_DIR, "demo_img3.png")

    # config
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # 1. load data & model
    input_image = Image.open(path_img).convert("RGB")
    model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
    model.eval()

    # 2. preprocess
    input_tensor = preprocess(input_image)
    input_bchw = input_tensor.unsqueeze(0)  # add the batch dimension: CHW -> BCHW

    # 3. to device
    if torch.cuda.is_available():
        input_bchw = input_bchw.to(device)
        model.to(device)

    # 4. forward
    with torch.no_grad():
        tic = time.time()
        print("input img tensor shape:{}".format(input_bchw.shape))  # 1*3*433*649
        output_4d = model(input_bchw)['out']  # take the value stored under the 'out' key, 4D
        output = output_4d[0]  # drop the batch dimension, 3D
        print("pass: {:.3f}s use: {}".format(time.time() - tic, device))
        print("output img tensor shape:{}".format(output.shape))  # prints 21*433*649

    output_predictions = output.argmax(0)  # class index of every pixel; output_predictions is 2D

    # 5. visualization
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
    colors = (colors % 255).numpy().astype("uint8")

    # plot the semantic segmentation predictions of 21 classes in each color
    r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize