在下面代码中报错
for val_d in val_bar: #此行报错
val_image,val_label=val_d
output=net(val_image.to(device))
predict_y=torch.max(output,dim=1)
print(predict_y.shape)
检查,发现验证集的dataloader数据,在做transform时,没有加上ToTensor变换,使得输入网络的是PIL,而不是tensor。
在下面代码中报错
for val_d in val_bar: #此行报错
val_image,val_label=val_d
output=net(val_image.to(device))
predict_y=torch.max(output,dim=1)
print(predict_y.shape)
检查,发现验证集的dataloader数据,在做transform时,没有加上ToTensor变换,使得输入网络的是PIL,而不是tensor。