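While reading some code today, I realized I wasn't very familiar with which parameters in a network actually get updated during training. A quick search on Baidu turned up this method, so I'm writing it down: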
Once your model is defined and instantiated, add the following lines after it to list its parameters:
# model: your instantiated nn.Module
for name, param in model.named_parameters():
    print(name, ' ', param.size())
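For a self-contained check, here is a minimal sketch on a small hypothetical model (not the network from this post). Note that `named_parameters()` yields every registered parameter, including frozen ones. To see only the parameters the optimizer will actually update, filter on `requires_grad`. The freezing step and the `trainable_parameters` helper are illustrative additions, not part of the original method.

```python
import torch.nn as nn

# Hypothetical toy model, used only to demonstrate the API
model = nn.Sequential(
    nn.Linear(19, 64),   # 0.weight [64, 19], 0.bias [64]
    nn.BatchNorm1d(64),  # 1.weight [64], 1.bias [64]; running stats are buffers, not parameters
    nn.Linear(64, 2),    # 2.weight [2, 64], 2.bias [2]
)

# Illustrative: freeze the first layer so the optimizer will not update it
for p in model[0].parameters():
    p.requires_grad = False

# Every registered parameter, frozen or not
for name, param in model.named_parameters():
    print(name, ' ', param.size(), ' requires_grad =', param.requires_grad)

def trainable_parameters(module):
    """Return (name, shape) pairs for parameters with requires_grad=True."""
    return [(name, tuple(p.shape))
            for name, p in module.named_parameters()
            if p.requires_grad]

print(trainable_parameters(model))
```

The loop prints all six parameters. `trainable_parameters(model)` returns only the four belonging to indices 1 and 2, because index 0 was frozen.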
Running it on my model printed the following:
gnn.k_weight torch.Size([192, 64])
gnn.node.0.weight torch.Size([64, 19])
gnn.node.0.bias torch.Size([64])
gnn.node.1.weight torch.Size([64])
gnn.node.1.bias torch.Size([64])
gnn.conv_params.0.weight torch.Size([64, 64])
gnn.conv_params.0.bias torch.Size([64])
gnn.conv_params.1.weight torch.Size([64, 64])
gnn.conv_params.1.bias torch.Size([64])
gnn.conv_params.2.weight torch.Size([64, 64])
gnn.conv_params.2.bias torch.Size([64])
gnn.bn2.0.weight torch.Size([64])
gnn.bn2.0.bias torch.Size([64])
gnn.bn2.1.weight torch.Size([64])
gnn.bn2.1.bias torch.Size([64])
gnn.bn2.2.weight torch.Size([64])
gnn.bn2.2.bias torch.Size([64])
gnn.bn3.weight torch.Size([64])
gnn.bn3.bias torch.Size([64])
gnn.graph_attn.0.weight torch.Size([64, 64])
gnn.graph_attn.0.bias torch.Size([64])
gnn.graph_attn.2.weight torch.Size([64])
gnn.graph_attn.2.bias torch.Size([64])
gnn.graph_attn.3.weight torch.Size([1, 64])
gnn.graph_attn.3.bias torch.Size([1])
gnn.graph_attn.4.weight torch.Size([1])
gnn.graph_attn.4.bias torch.Size([1])
mlp.m.0.weight torch.Size([128, 64])
mlp.m.0.bias torch.Size([128])
mlp.m.2.weight torch.Size([2, 128])
mlp.m.2.bias torch.Size([2])
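The names follow the module hierarchy. Each attribute name becomes a prefix, and layers inside an `nn.Sequential` are numbered by position. That is why `mlp.m` jumps from `0` to `2` (and `gnn.graph_attn` skips `1`): the layer at that index has no parameters. It is likely an activation or dropout, though the post doesn't say which. The sketch below is a hypothetical reconstruction of just the `mlp` part. Only the shapes come from the printout above. The `ReLU` at index 1 and the `MLP`/`Wrapper` class names are assumptions. It also totals the parameter count with `numel()`.

```python
import torch.nn as nn

class MLP(nn.Module):
    # Hypothetical reconstruction: only the parameter shapes come from the printout;
    # the parameter-free layer at index 1 is assumed to be a ReLU.
    def __init__(self, in_dim=64, hidden=128, out_dim=2):
        super().__init__()
        self.m = nn.Sequential(
            nn.Linear(in_dim, hidden),   # m.0: weight [128, 64], bias [128]
            nn.ReLU(),                   # m.1: no parameters, so it never appears in the list
            nn.Linear(hidden, out_dim),  # m.2: weight [2, 128], bias [2]
        )

    def forward(self, x):
        return self.m(x)

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = MLP()  # the attribute name becomes the "mlp." prefix

model = Wrapper()
for name, param in model.named_parameters():
    print(name, ' ', param.size())

# Total number of scalar parameters across all tensors
total = sum(p.numel() for p in model.parameters())
print(total)
```

This prints the same four `mlp.m.*` names and shapes as the listing, then the total: 128×64 + 128 + 2×128 + 2 = 8578.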
Keep at it!