如何利用ResNet18预训练模型在CUB-200-2011数据集上进行迁移学习实现图像分类?请结合Python和PyTorch详细说明步骤。

时间: 2024-11-05 17:17:51 AIGC 浏览: 181
在进行深度学习项目时,将预训练模型应用于特定任务是一种常见且高效的方法。CUB-200-2011数据集包含了丰富多样的鸟类图像,非常适合用来训练和测试图像分类模型。本回答将基于《利用ResNet18实现CUB-200-2011鸟类图像分类》的资源来详细说明如何在PyTorch框架中使用ResNet18预训练模型进行迁移学习以实现图像分类。 参考资源链接:[利用ResNet18实现CUB-200-2011鸟类图像分类](https://wenku.csdn.net/doc/7tk4e8vkh0?spm=1055.2569.3001.10343) 首先,确保你已经安装了PyTorch以及torchvision库。接下来,下载CUB-200-2011数据集和提供的预训练模型源代码。加载预训练的ResNet18模型,并冻结除最后几层以外的全部层的权重,这样可以在新任务上进行微调。 在PyTorch中,可以通过修改model.fc(全连接层)来匹配数据集的类别数。然后,对模型进行训练,可以使用数据集提供的训练集,验证集用于模型评估。在此过程中,建议逐步调整学习率,并使用交叉验证等技术来优化模型性能。 使用提供的python源码中的train.py文件可以开始模型训练过程,并根据模型在验证集上的表现来调整超参数。此外,还可以通过查看runs目录来监控模型训练过程中的损失曲线和准确率变化,以此来决定是否需要调整学习率或进行模型的进一步训练。 在模型训练完成后,使用test.py文件对模型进行最终的测试,并评估其在CUB-200-2011数据集上的分类性能。通过这种方式,你可以充分理解并掌握如何使用预训练模型进行迁移学习,并在特定数据集上取得良好的分类结果。如果你想进一步了解细节,例如数据预处理、模型训练的超参数设置、以及如何进行模型评估等,推荐深入研究《利用ResNet18实现CUB-200-2011鸟类图像分类》所提供的资源。这份资源详细描述了每一个步骤,并提供了必要的代码,能帮助你更好地理解和掌握迁移学习在图像分类中的应用。 参考资源链接:[利用ResNet18实现CUB-200-2011鸟类图像分类](https://wenku.csdn.net/doc/7tk4e8vkh0?spm=1055.2569.3001.10343)
阅读全文

相关推荐

import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms, models from torch.utils.data import Dataset,DataLoader import os import numpy as np from PIL import Image # ==================== CBAM模块实现 ==================== class ChannelAttention(nn.Module): def __init__(self, in_planes, reduction_ratio=16): super(ChannelAttention, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) self.fc = nn.Sequential( nn.Conv2d(in_planes, in_planes // reduction_ratio, 1, bias=False), nn.ReLU(), nn.Conv2d(in_planes // reduction_ratio, in_planes, 1, bias=False) ) self.sigmoid = nn.Sigmoid() def forward(self, x): avg_out = self.fc(self.avg_pool(x)) max_out = self.fc(self.max_pool(x)) return self.sigmoid(avg_out + max_out) class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): super(SpatialAttention, self).__init__() self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): avg_out = torch.mean(x, dim=1, keepdim=True) max_out, _ = torch.max(x, dim=1, keepdim=True) x = torch.cat([avg_out, max_out], dim=1) x = self.conv(x) return self.sigmoid(x) class CBAM(nn.Module): def __init__(self, channel, reduction_ratio=16, kernel_size=7): super(CBAM, self).__init__() self.ca = ChannelAttention(channel, reduction_ratio) self.sa = SpatialAttention(kernel_size) def forward(self, x): x = x * self.ca(x) x = x * self.sa(x) return x # ==================== 修改ResNet50 ==================== class BottleneckWithCBAM(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None): super(BottleneckWithCBAM, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride self.cbam = CBAM(planes * self.expansion) def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) out = self.cbam(out) return out def resnet50_cbam(pretrained=True, num_classes=200): model = models.resnet50(pretrained=pretrained) # 替换Bottleneck模块 for i in range(len(model.layer1)): model.layer1[i] = BottleneckWithCBAM(model.layer1[i].conv1.in_channels, model.layer1[i].conv1.out_channels) for i in range(len(model.layer2)): model.layer2[i] = BottleneckWithCBAM(model.layer2[i].conv1.in_channels, model.layer2[i].conv1.out_channels) for i in range(len(model.layer3)): model.layer3[i] = BottleneckWithCBAM(model.layer3[i].conv1.in_channels, model.layer3[i].conv1.out_channels) for i in range(len(model.layer4)): model.layer4[i] = BottleneckWithCBAM(model.layer4[i].conv1.in_channels, model.layer4[i].conv1.out_channels) # 修改最后的全连接层 model.fc = nn.Linear(model.fc.in_features, num_classes) return model # ==================== 数据集解析与划分 ==================== class CUBDataset(Dataset): def __init__(self, root_dir, transform=None, is_train=True): self.root = root_dir self.transform = transform # 解析数据文件 self._parse_files() # 根据train_test_split筛选数据 split_mask = (self.split_labels == 1) if is_train else (self.split_labels == 0) self.image_paths = self.image_paths[split_mask] self.class_labels = self.class_labels[split_mask] def _parse_files(self): # 读取images.txt with open(os.path.join(self.root, 'images.txt'), 'r') as f: image_entries = [line.strip().split() for line in f] self.image_paths = np.array([os.path.join(self.root, 'images', e[1]) for e in image_entries]) # 读取image_class_labels.txt with open(os.path.join(self.root, 'image_class_labels.txt'), 'r') as f: class_entries = [line.strip().split() for line in f] self.class_labels = np.array([int(e[1]) - 1 for e in class_entries]) # 类别索引从0开始 # 读取train_test_split.txt with open(os.path.join(self.root, 'train_test_split.txt'), 'r') as f: split_entries = [line.strip().split() for line in f] self.split_labels = np.array([int(e[1]) for e in split_entries]) def __len__(self): return len(self.image_paths) def __getitem__(self, idx): img_path = self.image_paths[idx] image = Image.open(img_path).convert('RGB') label = self.class_labels[idx] if self.transform: image = self.transform(image) return image, label # ==================== 数据预处理 ==================== def get_dataloaders(root_dir, batch_size=32): train_transforms = transforms.Compose([ transforms.Resize(512), # 高分辨率有利于细粒度分类 transforms.RandomCrop(448), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) test_transforms = transforms.Compose([ transforms.Resize(512), transforms.CenterCrop(448), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) train_dataset = CUBDataset(root_dir, train_transforms, is_train=True) test_dataset = CUBDataset(root_dir, test_transforms, is_train=False) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) return train_loader, test_loader # ==================== 训练函数 ==================== def train_model(model, train_loader, test_loader, device, num_epochs=100): criterion = nn.CrossEntropyLoss() optimizer = optim.AdamW(model.parameters(), lr=3e-4) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs) best_acc = 0.0 for epoch in range(num_epochs): model.train() running_loss = 0.0 for inputs, labels in train_loader: inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() * inputs.size(0) epoch_loss = running_loss / len(train_loader.dataset) scheduler.step() # 验证阶段 model.eval() correct = 0 total = 0 with torch.no_grad(): for inputs, labels in test_loader: inputs, labels = inputs.to(device), labels.to(device) outputs = model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() acc = 100 * correct / total print(f'Epoch {epoch + 1}/{num_epochs} Loss: {epoch_loss:.4f} Acc: {acc:.2f}%') if acc > best_acc: best_acc = acc torch.save(model.state_dict(), 'best_model.pth') print(f'Best Test Accuracy: {best_acc:.2f}%') return model # ==================== 主函数 ==================== if __name__ == '__main__': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') data_dir = r'D:\CUB_200_2011\CUB_200_2011' # 自动划分数据集 train_loader, test_loader = get_dataloaders(data_dir, batch_size=64) # 初始化模型(使用前文定义的resnet50_cbam) model = resnet50_cbam(pretrained=True, num_classes=200) model = model.to(device) # 训练流程(使用前文定义的train_model函数) train_model(model, train_loader, test_loader, device, num_epochs=100) The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead. C:\ProgramData\anaconda3\envs\pytorch-python3.8\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or None for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing weights=ResNet50_Weights.IMAGENET1K_V1. You can also use weights=ResNet50_Weights.DEFAULT to get the most up-to-date weights. Traceback (most recent call last): File "C:\Users\Administrator\PyCharmMiscProject\test.py", line 254, in <module> train_model(model, train_loader, test_loader, device, num_epochs=100) File "C:\Users\Administrator\PyCharmMiscProject\test.py", line 208, in train_model outputs = model(inputs) File "C:\ProgramData\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl return forward_call(*args, **kwargs) File "C:\ProgramData\anaconda3\envs\pytorch-python3.8\lib\site-packages\torchvision\models\resnet.py", line 285, in forward return self._forward_impl(x) File "C:\ProgramData\anaconda3\envs\pytorch-python3.8\lib\site-packages\torchvision\models\resnet.py", line 273, in _forward_impl x = self.layer1(x) File "C:\ProgramData\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl return forward_call(*args, **kwargs) File "C:\ProgramData\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\nn\modules\container.py", line 217, in forward input = module(input) File "C:\ProgramData\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl return forward_call(*args, **kwargs) File "C:\Users\Administrator\PyCharmMiscProject\test.py", line 90, in forward out += identity RuntimeError: The size of tensor a (256) must match the size of tensor b (64) at non-singleton dimension 1

最新推荐

recommend-type

tock-nlp-admin-server-22.3.0-sources.jar

tock-nlp-admin-server-22.3.0-sources.jar
recommend-type

tock-bot-connector-teams-22.9.2.jar

tock-bot-connector-teams-22.9.2.jar
recommend-type

【scratch2.0少儿编程-游戏原型-动画-项目源码】接女孩.zip

资源说明: 1:本资料仅用作交流学习参考,请切勿用于商业用途。更多精品资源请访问 https://blog.csdn.net/ashyyyy/article/details/146464041 2:一套精品实用scratch2.0少儿编程游戏、动画源码资源,无论是入门练手还是项目复用都超实用,省去重复开发时间,让开发少走弯路!
recommend-type

cybrid-api-bank-java-v0.75.3.jar

cybrid-api-bank-java-v0.75.3.jar
recommend-type

mutable-utils-1.0.4-javadoc.jar

mutable-utils-1.0.4-javadoc.jar
recommend-type

Node.js构建的运动咖啡馆RESTful API介绍

标题《sportscafeold:体育咖啡馆》指出了项目名称为“体育咖啡馆”,这个名字暗示了该项目可能是一个结合了运动和休闲主题的咖啡馆相关的网络服务平台。该项目运用了多种技术栈,核心的开发语言为JavaScript,这从标签中可以得到明确的信息。 从描述中可以提取以下知识点: 1. **Node.js**:体育咖啡馆项目使用了Node.js作为服务器端运行环境。Node.js是一个基于Chrome V8引擎的JavaScript运行环境,它能够使得JavaScript应用于服务器端开发。Node.js的事件驱动、非阻塞I/O模型使其适合处理大量并发连接,这对于RESTFUL API的构建尤为重要。 2. **Express Framework**:项目中使用了Express框架来创建RESTFUL API。Express是基于Node.js平台,快速、灵活且极简的Web应用开发框架。它提供了构建Web和移动应用的强大功能,是目前最流行的Node.js Web应用框架之一。RESTFUL API是一组遵循REST原则的应用架构,其设计宗旨是让Web服务通过HTTP协议进行通信,并且可以使用各种语言和技术实现。 3. **Mongoose ORM**:这个项目利用了Mongoose作为操作MongoDB数据库的接口。Mongoose是一个对象文档映射器(ODM),它为Node.js提供了MongoDB数据库的驱动。通过Mongoose可以定义数据模型,进行数据库操作和查询,从而简化了对MongoDB数据库的操作。 4. **Passport.js**:项目中采用了Passport.js库来实现身份验证系统。Passport是一个灵活的Node.js身份验证中间件,它支持多种验证策略,例如用户名和密码、OAuth等。它提供了标准化的方法来为用户登录提供认证,是用户认证功能的常用解决方案。 5. **版权信息**:项目的版权声明表明了Sportscafe 2015是版权所有者,这表明项目或其相关内容最早发布于2015年或之前。这可能表明该API背后有商业实体的支持或授权使用。 从【压缩包子文件的文件名称列表】中我们可以了解到,该文件的版本控制仓库使用的是“master”分支。在Git版本控制系统中,“master”分支通常用于存放当前可部署的稳定版本代码。在“master”分支上进行的更改通常都是经过测试且准备发布到生产环境的。 综上所述,我们可以知道体育咖啡馆项目是一个利用现代JavaScript技术栈搭建的后端服务。它包含了处理HTTP请求的Express框架、连接MongoDB数据库的Mongoose库和实现用户身份验证的Passport.js中间件。该项目可用于构建提供体育信息、咖啡馆菜单信息、预约服务等的Web应用或API服务,这为体育咖啡馆的营销、用户体验和数据管理提供了可能。 考虑到文档资料的提及,该项目的安装和API文档应该包含在项目资料中,可能在项目的README文件或其他说明文档中。对于项目的使用者或者开发者而言,这部分文档非常重要,因为它们可以提供详细的信息和指导,帮助用户快速部署和使用该API。 总结来说,这是一套针对体育咖啡馆相关业务的后端解决方案,它使用了流行的开源技术栈,可以灵活地应用于体育咖啡馆的网络服务中,如信息发布、在线预约、会员管理等。
recommend-type

【LNR优化与用户体验】:一文看透互操作优化如何提升用户感知

# 摘要 LNR(邻区关系优化)是提升移动通信网络性能与用户体验的关键技术。本文系统阐述了LNR优化的基本概念、理论基础与技术实现路径,构建了包括信道质量评估、干扰建模及关键性能指标分析在内的数学模型,并探讨了参数配置、邻区管理及AI与大数据在LNR优化中的应用。通过量化用户感知指标,分析了LNR优化对视频流畅度、页面加载速度及切换性能的实际改善效果,并结合典型场景验证了其优化成效。本文进一步展望了LNR技术在5G及未来网络中的发展趋势与技术挑战。
recommend-type

Java1.8 的编程语言、使用场景、版本号、厂商、是否开源、发行日期、终止日期、可替代产品、推荐产品是什么

以下是关于 **Java 1.8** 的全面信息汇总,涵盖其编程语言特性、使用场景、版本号、厂商、是否开源、发行日期、终止支持日期、可替代产品和推荐产品。 --- ## ✅ Java 1.8 综合信息表 | 项目 | 内容 | |------|------| | **编程语言** | Java | | **版本号** | Java 1.8(也称为 Java 8) | | **厂商** | Oracle、Adoptium、Amazon(Corretto)、Azul(Zulu)、Red Hat、IBM 等 | | **是否开源** | ✅ 是(OpenJDK 1.8 是开源的,Oracle
recommend-type

Java开发的教区牧民支持系统介绍

根据给定文件信息,下面将详细阐述相关知识点: ### 标题知识点 #### catecumenus-java: 教区牧民支持系统 - **Java技术栈应用**:标题提到的“catecumenus-java”表明这是一个使用Java语言开发的系统。Java是目前最流行的编程语言之一,广泛应用于企业级应用、Web开发、移动应用等,尤其是在需要跨平台运行的应用中。Java被设计为具有尽可能少的实现依赖,所以它可以在多种处理器上运行。 - **教区牧民支持系统**:从标题来看,这个系统可能面向的是教会管理或教区管理,用来支持牧民(教会领导者或牧师)的日常管理工作。具体功能可能包括教友信息管理、教区活动安排、宗教教育资料库、财务管理、教堂资源调配等。 ### 描述知识点 #### 儿茶类 - **儿茶素(Catechin)**:描述中提到的“儿茶类”可能与“catecumenus”(新信徒、教徒)有关联,暗示这个系统可能与教会或宗教教育相关。儿茶素是一类天然的多酚类化合物,常见于茶、巧克力等植物中,具有抗氧化、抗炎等多种生物活性,但在系统标题中可能并无直接关联。 - **系统版本号**:“0.0.1”表示这是一个非常初期的版本,意味着该系统可能刚刚开始开发,功能尚不完善。 ### 标签知识点 #### Java - **Java语言特点**:标签中明确提到了“Java”,这暗示了整个系统都是用Java编程语言开发的。Java的特点包括面向对象、跨平台(即一次编写,到处运行)、安全性、多线程处理能力等。系统使用Java进行开发,可能看重了这些特点,尤其是在构建可扩展、稳定的后台服务。 - **Java应用领域**:Java广泛应用于企业级应用开发中,包括Web应用程序、大型系统后台、桌面应用以及移动应用(Android)。所以,此系统可能也会涉及这些技术层面。 ### 压缩包子文件的文件名称列表知识点 #### catecumenus-java-master - **Git项目结构**:文件名称中的“master”表明了这是Git版本控制系统中的一个主分支。在Git中,“master”分支通常被用作项目的主干,是默认的开发分支,所有开发工作都是基于此分支进行的。 - **项目目录结构**:在Git项目中,“catecumenus-java”文件夹应该包含了系统的源代码、资源文件、构建脚本、文档等。文件夹可能包含各种子文件夹和文件,比如src目录存放Java源代码,lib目录存放相关依赖库,以及可能的build.xml文件用于构建过程(如Ant或Maven构建脚本)。 ### 结合以上信息的知识点整合 综合以上信息,我们可以推断“catecumenus-java: 教区牧民支持系统”是一个使用Java语言开发的系统,可能正处于初级开发阶段。这个系统可能是为了支持教会内部管理,提供信息管理、资源调度等功能。其使用Java语言的目的可能是希望利用Java的多线程处理能力、跨平台特性和强大的企业级应用支持能力,以实现一个稳定和可扩展的系统。项目结构遵循了Git版本控制的规范,并且可能采用了模块化的开发方式,各个功能模块的代码和资源文件都有序地组织在不同的子文件夹内。 该系统可能采取敏捷开发模式,随着版本号的增加,系统功能将逐步完善和丰富。由于是面向教会的内部支持系统,对系统的用户界面友好性、安全性和数据保护可能会有较高的要求。此外,考虑到宗教性质的敏感性,系统的开发和使用可能还需要遵守特定的隐私和法律法规。
recommend-type

LNR切换成功率提升秘籍:参数配置到网络策略的全面指南

# 摘要 LNR(LTE to NR)切换技术是5G网络部署中的关键环节,直接影