源码中有两个用于测试的脚本: test.py 和 evaluate_gpu.py 。其中, test.py 加载通过脚本 train.py 训练好的模型,实现对 query 和 gallery 图片的特征提取;本文对脚本 test.py 进行解析。
1. 加载模型和数据
首先需要载入训练好的模型,这里以基于 Resnet50 输出类别为 751 类的行人重识别模型 ft_net 为例。
model_structure = ft_net(751)
model = load_network(model_structure)
然后需要载入经过预处理的 gallery 和 query 数据集
data_transforms = transforms.Compose([
transforms.Resize((256,128), interpolation=3),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
image_datasets = {
x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query']}
dataloaders = {
x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
shuffle