How to install thop (PyTorch-OpCounter)
Through PyPI
pip install thop
Using GitHub (always latest)
pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git
import torch
import torch.nn as nn
import torchvision
from functools import reduce
# Report the installed library versions when this module is loaded.
print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)
def FirstConv(in_channels, out_channels):
    """Build the stem block: 3x3 stride-2 convolution -> BatchNorm2d -> ReLU.

    Args:
        in_channels: channels of the input image/tensor.
        out_channels: channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` of the three layers, indexed 0..2.
    """
    stem_layers = [
        # padding=1 with a 3x3 kernel and stride 2 halves even spatial sizes.
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, groups=1, bias=True),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*stem_layers)
def LastConv(in_channels, out_channels, mid_channels):
    """Build the fusion/upsampling block placed before the output heads.

    A 1x1 conv projects to ``mid_channels``, then ``PixelShuffle(2)`` doubles H and W
    while dividing the channel count by 4. The following convolutions expect
    ``out_channels`` inputs, so callers must pass ``mid_channels == 4 * out_channels``.
    A grouped 3x3 conv (``out_channels // 4`` groups) and a 1x1 conv, each followed
    by BatchNorm2d, refine the result; a ReLU closes the block.

    Returns:
        An ``nn.Sequential`` of nine layers in the order listed above.
    """
    expand_and_upsample = [
        nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, groups=1),
        nn.BatchNorm2d(mid_channels),
        nn.ReLU(inplace=True),
        nn.PixelShuffle(upscale_factor=2),
    ]
    refine = [
        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1,
                  groups=out_channels // 4),
        nn.BatchNorm2d(out_channels),
        nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, groups=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*expand_and_upsample, *refine)
class ChannelShuffle(nn.Module):
    """Interleave channels across ``groups`` groups (ShuffleNet channel shuffle).

    Expects a 4-D tensor ``[N, C, H, W]`` whose channel count ``C`` is divisible by
    ``groups``; otherwise the reshape raises ``RuntimeError``.
    """

    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        # Number of channel groups to interleave.
        self.groups = groups

    def forward(self, x):
        # channel shuffle: [N, C, H, W] -> [N, g, C/g, H, W] -> [N, C/g, g, H, W] -> [N, C, H, W]
        N, C, H, W = x.shape
        g = self.groups
        # Integer division is exact; int(C / g) needlessly round-trips through a float.
        channels_per_group = C // g
        return x.view(N, g, channels_per_group, H, W).permute(0, 2, 1, 3, 4).contiguous().view(N, C, H, W)
class EltwiseAdd(nn.Module):
    """Sum two or more inputs with ``+``, left to right."""

    def __init__(self):
        super().__init__()

    def forward(self, *inputs):
        # At least two operands are required.
        assert len(inputs) > 1
        total = inputs[0]
        for term in inputs[1:]:
            total = total + term
        return total
class Concat(nn.Module):
    """Concatenate tensors with ``torch.cat`` and optionally check the channel count.

    Args:
        out_channels: if not ``None``, the expected size of dimension 1 of the
            concatenated result (checked with ``assert``).
    """

    def __init__(self, out_channels=None):
        super().__init__()
        self.out_channels = out_channels

    def forward(self, *tensors, dim=1):
        joined = torch.cat(tensors, dim=dim)
        expected = self.out_channels
        if expected is not None:
            # The check always inspects dimension 1, whatever ``dim`` was concatenated.
            assert joined.shape[1] == expected, 'The actual number of output channels not match the predefined.'
        return joined
class HalfSplit(nn.Module):
    """Return one of the two ``torch.chunk`` halves of the input along ``dim``.

    For an odd size the first half is the larger one (``torch.chunk`` semantics).

    Args:
        dim: dimension to split.
        first_half: return the first chunk if True, otherwise the second.
    """

    def __init__(self, dim=0, first_half=True):
        super().__init__()
        self.first_half = first_half
        self.dim = dim

    def forward(self, input):
        chunk_index = 0 if self.first_half else 1
        return torch.chunk(input, 2, dim=self.dim)[chunk_index]
class ShuffleBottleneck(nn.Module):
    """Grouped 1x1 conv -> BN -> ReLU -> channel shuffle -> grouped 3x3 conv -> BN -> grouped 1x1 conv -> BN.

    No activation is applied after the second and third BatchNorm layers.

    Args:
        in_channels: channels of the input; must be even (``groups=2`` in conv1).
        out_channels: channels produced by conv3; must be even (``groups=2``).
        mid_channels: internal width; defaults to ``in_channels``. Must be a
            multiple of 4 so that conv2's ``mid_channels // 4`` groups divide it.
        mid_stride: stride of the 3x3 conv; ``None`` (the default) means 1.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, mid_stride=None):
        super(ShuffleBottleneck, self).__init__()
        mid_channels = mid_channels or in_channels
        # A None stride is not a valid nn.Conv2d value, so treat the default as 1.
        mid_stride = 1 if mid_stride is None else mid_stride
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=1, stride=1,
                               groups=2, bias=True)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.channel_shuffle = ChannelShuffle(groups=2)
        self.conv2 = nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=3,
                               stride=mid_stride, padding=1, groups=mid_channels // 4, bias=True)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        # conv3 consumes conv2's output, so its input width is mid_channels (not in_channels).
        self.conv3 = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, stride=1,
                               groups=2, bias=True)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.channel_shuffle(x)
        x = self.bn2(self.conv2(x))
        x = self.bn3(self.conv3(x))
        return x
class ShuffleResUnitConcat(nn.Module):
    """Down-sampling unit: a pooled shortcut concatenated with a stride-2 bottleneck, then ReLU.

    The shortcut (2x2 average pool, stride 2) keeps ``in_channels`` channels. The
    bottleneck produces ``out_channels - in_channels``, so the concatenation has
    ``out_channels`` channels, which ``Concat`` checks.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        mid_channels: internal bottleneck width; defaults to ``in_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = in_channels
        self.branch1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.branch2 = ShuffleBottleneck(in_channels, out_channels - in_channels, mid_channels, mid_stride=2)
        self.concat = Concat(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Arguments are evaluated left to right: shortcut first, then the bottleneck.
        merged = self.concat(self.branch1(x), self.branch2(x))
        return self.relu(merged)
class ShuffleResUnitEltwiseAdd(nn.Module):
    """Stride-1 unit: keep the first channel half, transform the second half, concatenate.

    Despite the class name, the two halves are concatenated, not added, and the
    unit does not shuffle channels itself.

    Args:
        in_channels: not used directly. Each split half is fed to a bottleneck
            expecting ``out_channels // 2`` channels, so this should equal
            ``out_channels``.
        out_channels: channels of the output (checked by ``Concat``).
        mid_channels: internal width of the bottleneck. ``None`` (the default) lets
            the bottleneck fall back to its input width, ``out_channels // 2``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super(ShuffleResUnitEltwiseAdd, self).__init__()
        half_channels = out_channels // 2
        self.first_half = HalfSplit(dim=1, first_half=True)
        self.second_split = HalfSplit(dim=1, first_half=False)
        # Forward mid_channels instead of silently discarding it; None keeps the previous behavior.
        self.branch2 = ShuffleBottleneck(half_channels, half_channels, mid_channels, mid_stride=1)
        self.concat = Concat(out_channels)

    def forward(self, x):
        x1 = self.first_half(x)
        x2 = self.branch2(self.second_split(x))
        return self.concat(x1, x2)
class ASPPBottleneck(nn.Module):
    """Dilated grouped 3x3 conv -> BN -> 1x1 conv -> BN -> ReLU, preserving H and W.

    Args:
        in_channels: input channels. conv1 uses ``in_channels // 4`` groups, so
            ``in_channels`` must be at least 4. That group count must divide both
            ``in_channels`` and ``mid_channels``.
        out_channels: output channels.
        mid_channels: width between the two convolutions; defaults to ``in_channels``.
        dilation: dilation rate (and matching padding) of the 3x3 conv. ``None``
            (the default) means 1.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, dilation=None):
        super(ASPPBottleneck, self).__init__()
        mid_channels = mid_channels or in_channels
        # None would otherwise reach nn.Conv2d as padding/dilation, which is not a valid value.
        dilation = 1 if dilation is None else dilation
        # padding == dilation keeps H and W unchanged for a 3x3 kernel at stride 1.
        self.dilation_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3,
                                        stride=1, padding=dilation, dilation=dilation, groups=in_channels // 4)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        # Dilation has no effect on a 1x1 kernel; kept for parity with the original definition.
        self.dilation_conv2 = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1,
                                        stride=1, dilation=dilation, groups=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.bn1(self.dilation_conv1(x))
        x = self.bn2(self.dilation_conv2(x))
        return self.relu(x)
class ASPP(nn.Module):
    """Four parallel branches at different dilation rates, concatenated and fused by a 1x1 conv.

    Branch 1 is a 1x1 conv -> BN -> ReLU. Branches 2-4 are ``ASPPBottleneck`` blocks
    with dilations ``dilation2``..``dilation4``. Every branch emits ``mid_channels``
    channels (default ``out_channels // 4``), and ``Concat`` checks that the four
    together total ``out_channels``. A 1x1 conv -> BN -> ReLU tail follows.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, dilation1=1, dilation2=4, dilation3=8,
                 dilation4=16):
        super().__init__()
        branch_width = mid_channels or out_channels // 4
        self.dilation1 = nn.Sequential(
            # Dilation has no effect on a 1x1 kernel.
            nn.Conv2d(in_channels, branch_width, kernel_size=1, stride=1, dilation=dilation1, groups=1),
            nn.BatchNorm2d(branch_width),
            nn.ReLU(inplace=True),
        )
        self.dilation2 = ASPPBottleneck(in_channels=in_channels, out_channels=branch_width, dilation=dilation2)
        self.dilation3 = ASPPBottleneck(in_channels=in_channels, out_channels=branch_width, dilation=dilation3)
        self.dilation4 = ASPPBottleneck(in_channels=in_channels, out_channels=branch_width, dilation=dilation4)
        self.concat = Concat(out_channels=out_channels)
        self.tail = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, groups=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        branches = (self.dilation1, self.dilation2, self.dilation3, self.dilation4)
        pyramid = [branch(x) for branch in branches]
        return self.tail(self.concat(*pyramid))
class ParkslotNet(nn.Module):
    """ShuffleNet-style encoder with three ASPP side branches fused into three 1x1 output heads.

    Pipeline (strides relative to the input, assuming its sides are divisible by 32):
      * ``first_conv`` (stride 2) and stage ``res1`` (stride 2): 72 channels at 1/4.
      * stage ``res2`` (stride 2): 120 channels at 1/8; ASPP -> 96 channels.
      * stage ``res3`` (stride 2): 240 channels at 1/16; ASPP -> 144 channels,
        then PixelShuffle(2) -> 36 channels at 1/8.
      * stage ``res4`` (stride 2): 480 channels at 1/32; ASPP -> 384 channels,
        then PixelShuffle(4) -> 24 channels at 1/8.
      * concat (96 + 36 + 24 = 156 channels) -> ``last_conv`` (PixelShuffle(2),
        36 channels at 1/4) -> heads ``conv_reg`` / ``conv_seg`` / ``conv_pt``.

    Args:
        reg_num_classes: output channels of ``conv_reg``.
        seg_num_classes: output channels of ``conv_seg``.
        pt_num_classes: output channels of ``conv_pt``.
        num_classes: accepted but currently unused.
    """

    def __init__(self, reg_num_classes=8, seg_num_classes=4, pt_num_classes=4, num_classes=16):
        super(ParkslotNet, self).__init__()
        # Keep this registration order: init_params() walks self.modules() in that order,
        # so reordering would change which random draws each layer receives.
        self.first_conv = FirstConv(in_channels=3, out_channels=24)
        # stage 1
        self.res1a = ShuffleResUnitConcat(in_channels=24, out_channels=72)
        self.res1b = ShuffleResUnitEltwiseAdd(in_channels=72, out_channels=72)
        self.res1b_channel_shuffle = ChannelShuffle(groups=2)
        self.res1c = ShuffleResUnitEltwiseAdd(in_channels=72, out_channels=72)
        self.res1c_channel_shuffle = ChannelShuffle(groups=2)
        self.res1d = ShuffleResUnitEltwiseAdd(in_channels=72, out_channels=72)
        self.res1d_channel_shuffle = ChannelShuffle(groups=2)
        # stage 2
        self.res2a = ShuffleResUnitConcat(in_channels=72, out_channels=120)
        self.res2b = ShuffleResUnitEltwiseAdd(in_channels=120, out_channels=120)
        self.res2b_channel_shuffle = ChannelShuffle(groups=2)
        self.res2c = ShuffleResUnitEltwiseAdd(in_channels=120, out_channels=120)
        self.res2c_channel_shuffle = ChannelShuffle(groups=2)
        self.res2d = ShuffleResUnitEltwiseAdd(in_channels=120, out_channels=120)
        self.res2d_channel_shuffle = ChannelShuffle(groups=2)
        self.res2_aspp = ASPP(in_channels=120, out_channels=96)
        # stage 3
        self.res3a = ShuffleResUnitConcat(in_channels=120, out_channels=240)
        self.res3b = ShuffleResUnitEltwiseAdd(in_channels=240, out_channels=240)
        self.res3b_channel_shuffle = ChannelShuffle(groups=2)
        self.res3c = ShuffleResUnitEltwiseAdd(in_channels=240, out_channels=240)
        self.res3c_channel_shuffle = ChannelShuffle(groups=2)
        self.res3d = ShuffleResUnitEltwiseAdd(in_channels=240, out_channels=240)
        self.res3d_channel_shuffle = ChannelShuffle(groups=2)
        self.res3f = ShuffleResUnitEltwiseAdd(in_channels=240, out_channels=240)
        self.res3f_channel_shuffle = ChannelShuffle(groups=2)
        self.res3_aspp = ASPP(in_channels=240, out_channels=144)
        # 144 channels -> 36 channels at twice the resolution
        self.res3_pixelshuffle = nn.PixelShuffle(upscale_factor=2)
        # stage 4
        self.res4a = ShuffleResUnitConcat(in_channels=240, out_channels=480)
        self.res4b = ShuffleResUnitEltwiseAdd(in_channels=480, out_channels=480)
        self.res4b_channel_shuffle = ChannelShuffle(groups=2)
        self.res4c = ShuffleResUnitEltwiseAdd(in_channels=480, out_channels=480)
        self.res4c_channel_shuffle = ChannelShuffle(groups=2)
        self.res4d = ShuffleResUnitEltwiseAdd(in_channels=480, out_channels=480)
        self.res4d_channel_shuffle = ChannelShuffle(groups=2)
        self.res4_aspp = ASPP(in_channels=480, out_channels=384)
        # 384 channels -> 24 channels at four times the resolution
        self.res4_pixelshuffle = nn.PixelShuffle(upscale_factor=4)
        # fusion: 96 + 36 + 24 = 156 channels
        self.concat = Concat(out_channels=156)
        self.last_conv = LastConv(in_channels=156, out_channels=36, mid_channels=144)
        # output heads
        self.conv_reg = nn.Conv2d(in_channels=36, out_channels=reg_num_classes, kernel_size=1, stride=1, groups=1)
        self.conv_seg = nn.Conv2d(in_channels=36, out_channels=seg_num_classes, kernel_size=1, stride=1, groups=1)
        self.conv_pt = nn.Conv2d(in_channels=36, out_channels=pt_num_classes, kernel_size=1, stride=1, groups=1)
        self.init_params()

    def init_params(self):
        """Re-initialise all submodules.

        Conv2d/Linear weights get Kaiming-normal values and their biases zero.
        BatchNorm2d layers get weight=1, bias=0, running_var=1 and running_mean=0.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                # Conv2d and Linear always define ``bias`` (None when bias=False).
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.running_var, 1)
                nn.init.constant_(m.running_mean, 0)

    def forward(self, x, labels=None, mode='eval'):
        """Run the network and return ``(out1, out2, out3)`` from conv_reg, conv_seg and conv_pt.

        ``labels`` and ``mode`` are accepted but not used by this method.
        """
        ## conv1
        x = self.first_conv(x)
        ## res1
        x = self.res1a(x)
        x = self.res1b_channel_shuffle(self.res1b(x))
        x = self.res1c_channel_shuffle(self.res1c(x))
        x = self.res1d_channel_shuffle(self.res1d(x))
        ## res2
        x2 = self.res2a(x)
        x2 = self.res2b_channel_shuffle(self.res2b(x2))
        x2 = self.res2c_channel_shuffle(self.res2c(x2))
        x2 = self.res2d_channel_shuffle(self.res2d(x2))
        x2_residual = self.res2_aspp(x2)
        ## res3
        x3 = self.res3a(x2)
        x3 = self.res3b_channel_shuffle(self.res3b(x3))
        x3 = self.res3c_channel_shuffle(self.res3c(x3))
        x3 = self.res3d_channel_shuffle(self.res3d(x3))
        x3 = self.res3f_channel_shuffle(self.res3f(x3))
        x3_residual = self.res3_aspp(x3)
        x3_residual = self.res3_pixelshuffle(x3_residual)
        ## res4
        x4 = self.res4a(x3)
        x4 = self.res4b_channel_shuffle(self.res4b(x4))
        x4 = self.res4c_channel_shuffle(self.res4c(x4))
        x4 = self.res4d_channel_shuffle(self.res4d(x4))
        x4_residual = self.res4_aspp(x4)
        x4_residual = self.res4_pixelshuffle(x4_residual)
        ## concat: the three branches share the same spatial size here
        out = self.concat(x2_residual, x3_residual, x4_residual)
        out = self.last_conv(out)
        ## output
        out1 = self.conv_reg(out)
        out2 = self.conv_seg(out)
        out3 = self.conv_pt(out)
        return out1, out2, out3
if __name__ == '__main__':
    net = ParkslotNet()
    print(net)
    print('# generator parameters:', sum(param.numel() for param in net.parameters()))
    # named_parameters() already yields (name, tensor) pairs; no need to zip with parameters().
    for name, param in net.named_parameters():
        print(name, param.numel())
    # Smoke-test the forward pass on a random 640x640 RGB-shaped tensor. eval() + no_grad()
    # avoid updating BatchNorm running statistics and building an autograd graph.
    dummy_input = torch.randn(1, 3, 640, 640)
    net.eval()
    with torch.no_grad():
        out1, out2, out3 = net(dummy_input)
    print(out1.shape)
    print(out2.shape)
    print(out3.shape)