两种可能:
一:transforms没有把数据集中的图片转换成想要的尺寸大小,如Resnet等网络输入需要224x224大小的图像,在Resize时,不能用 transforms.Resize(224),而因该用transforms.Resize([224, 224]),官方文档有函数解释
二. 可能是因为自己的数据集中既有RGB图像也有灰度图像,通道不统一
参考https://discuss.pytorch.org/t/runtimeerror-invalid-argument-0-sizes-of-tensors-must-match-except-in-dimension-0-got-3-and-2-in-dimension-1/23890
可以通过以下方法解决:
在__getitem__() 方法中添加 image= Image.open(image).convert('RGB')
我这里是第一种问题,把size改成源码要求的大小就可以了
链接:https://www.jianshu.com/p/9e866d02ddbd