1. 先定义 ResNet18 模块的网络结构：
class Resnet18(nn.Module):
    """Feature extractor assembled from the layers of ``models.resnet18``.

    The wrapper chains, in order, the stem (``conv1``, ``bn1``, ``relu``,
    ``maxpool``), the four residual stages (``layer1`` to ``layer4``) and
    ``avgpool`` of the loaded model.  The classifier head (``fc``) is not
    part of the chain, so the module returns pooled feature maps rather
    than class logits.
    """

    def __init__(self):
        """Load ``models.resnet18`` and keep every layer up to ``avgpool``."""
        super().__init__()
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # releases in favour of the `weights=` argument; it is kept here so
        # the block also works on older torchvision versions -- confirm the
        # version in use before switching.
        backbone = models.resnet18(pretrained=True)
        # The attribute name `layer` is preserved on purpose: it prefixes the
        # parameter names in `state_dict()`, so renaming it would break
        # loading of existing checkpoints.
        self.layer = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
            backbone.avgpool,
        )

    def forward(self, x):
        """Pass ``x`` through the wrapped layers and return the result.

        Args:
            x: Input accepted by the first wrapped layer
                (presumably an image batch of shape (N, 3, H, W) -- TODO confirm).

        Returns:
            The output of the last wrapped layer (``avgpool``).
        """
        return self.layer(x)
将上面的 Resnet18 类添加到 conv.py 的末尾。
注册模块