# ResNet 网络结构
# 导入必要的模块
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
# 定义 ResnetbaseBlock 类（ResNet 基本残差块）
class ResnetbaseBlock(nn.Module):
    """ResNet basic residual block: two 3x3 convolutions plus a shortcut.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.

    The shortcut is the identity when input and output shapes already match
    (``in_channels == out_channels`` and ``stride == 1``). Otherwise it is a
    1x1 convolution followed by batch normalization, so that the two branches
    have the same shape and can be added.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution and of the projection
            shortcut. The default of 1 keeps the spatial size unchanged.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        # Convolutions are bias-free: each is followed by a BatchNorm whose
        # learnable shift makes a convolution bias redundant.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if stride != 1 or in_channels != out_channels:
            # Projection shortcut: match channel count and spatial size so
            # the residual addition in forward() is well-defined.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # Parameter-free identity: adds no entries to the state_dict.
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual block to ``x``.

        Args:
            x: Input feature map with ``in_channels`` channels; presumably
                laid out as (N, C, H, W), as ``nn.Conv2d`` expects.

        Returns:
            Non-negative tensor with ``out_channels`` channels. Its spatial
            size is divided by ``stride`` (unchanged when ``stride == 1``).
        """
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.bn2(self.conv2(out))
        # No activation before the sum: the residual branch must be able to
        # contribute negative values; the ReLU is applied after the addition.
        out += residual
        return F.relu(out)