PyTorch Flower Recognition
Date: 2025-02-16 15:06:11 · AIGC
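### Flower Image Classification with PyTorch
#### Data Preparation and Loading
Building an effective flower image classifier starts with preparing and preprocessing the dataset. This example uses the Flowers-102 dataset[^1]. Images are loaded with `torchvision.datasets.ImageFolder`, which expects one sub-folder per class (for example `dataset/train/<class_name>/*.jpg`) and derives the labels from the folder names. The preprocessing steps are chained with `transforms.Compose`. The training pipeline applies random crops and horizontal flips for augmentation. The validation pipeline uses a deterministic resize and center crop. Both normalize with the ImageNet channel means and standard deviations that the pretrained backbone expects.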
```python
import torch
import torch.utils.data as data
from torchvision import datasets, transforms

# Training: random crop + flip for augmentation; validation: deterministic resize + center crop.
# Both normalize with the ImageNet mean/std expected by the pretrained ResNet.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Expects dataset/train/<class_name>/... and dataset/val/<class_name>/...
image_datasets = {x: datasets.ImageFolder(root=f'dataset/{x}', transform=data_transforms[x])
                  for x in ['train', 'val']}
# Shuffle only the training set; the validation order does not matter.
dataloaders = {x: data.DataLoader(image_datasets[x], batch_size=32, shuffle=(x == 'train'))
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
```
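#### Building the Model
Next, create the neural network used for classification. The classic residual network ResNet (here ResNet-18) serves as the backbone. Its final fully connected layer is replaced with one whose output size equals the number of flower classes. Whether ImageNet-pretrained weights are loaded determines how the network is initialized. With pretrained weights the model is fine-tuned (transfer learning). Without them it is trained from random initialization.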
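```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.models as models

# ResNet-18 with ImageNet weights (weights= API, torchvision >= 0.13; older versions used pretrained=True).
model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Replace the 1000-class ImageNet head with one output per flower class.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()
# Fine-tune all parameters with SGD + momentum.
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.1 every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
```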
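#### Training
Once the loss function and optimizer are set up, training can begin. Each epoch starts with a training phase. In this phase the model is in training mode and its parameters are updated. A validation phase follows. There the model is in evaluation mode, no gradients are computed, and the current performance is measured.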
```python
import copy
import torch

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # batch norm uses batch statistics
            else:
                model.eval()   # batch norm uses running statistics
            running_loss = 0.0
            running_corrects = 0
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Track gradients only during the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()
            if phase == 'train':
                scheduler.step()  # advance the learning-rate schedule once per epoch
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            # Keep a copy of the weights with the best validation accuracy
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
    print(f'Best val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model
```
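`scheduler.step()` runs once per epoch, so `StepLR(step_size=7, gamma=0.1)` sets the learning rate for epoch `e` to `0.001 * 0.1 ** (e // 7)`. The short sketch below replays that schedule on a single dummy parameter. It is an illustration only; no real training happens. It records the learning rate in effect at the start of each epoch.

```python
import torch
from torch import optim
from torch.optim import lr_scheduler

# A single dummy parameter is enough to drive the optimizer and scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

lrs = []
for epoch in range(21):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used during this epoch
    optimizer.step()   # stands in for one epoch of updates (no gradients here, so nothing changes)
    scheduler.step()   # called once per epoch, as in train_model()

print(lrs[0], lrs[7], lrs[14])
```

With the default `num_epochs=25`, epochs 0 to 6 train at 0.001, epochs 7 to 13 at 1e-4, epochs 14 to 20 at 1e-5, and epochs 21 to 24 at 1e-6.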
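With the steps above in place, calling `train_model()` runs the whole pipeline and returns the trained flower classifier, for example `model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)`.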
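#### Testing and Prediction
Once the model is well trained, it can be applied to new, unseen images. Beyond plain prediction, the referenced implementation also offers two extras. TTA (Test Time Augmentation)[^2] averages predictions over several augmented copies of each test image to improve accuracy. CAM (Class Activation Mapping)[^2] highlights the image regions that drove a particular class decision, which helps explain what the model is looking at.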
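Finally, a note on the development environment: Anaconda is recommended for managing dependencies, together with an IDE such as PyCharm or VSCode for writing the code[^3].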