图像字幕生成:从模型训练到结果展示
立即解锁
发布时间: 2025-08-30 00:38:26 阅读量: 5 订阅数: 10 AIGC 


PyTorch Lightning实战指南
### 图像字幕生成:从模型训练到结果展示
#### 1. 前期准备
在开始模型训练之前,有一些重要的前期准备工作。首先,需要将 `<pad>` 标记添加到词汇表中,并且要在添加其他标记之前进行。这样做可以确保该标记的整数值被指定为 0,这与 `coco_collate_fn` 中的编程逻辑一致,在创建一批填充字幕时会直接使用零(`torch.zeros()`)。最后,使用 `pickle.dump` 方法将词汇表持久化到 `coco_data` 目录中。
下载数据集和组装数据是一次性的处理步骤。如果要重新运行模型以恢复或重新开始训练,则不需要重复这些步骤,可以直接从后续步骤开始。
#### 2. 模型训练
模型训练涉及多个步骤,包括使用 `torch.utils.data.Dataset` 和 `torch.utils.data.DataLoader` 类加载数据,使用 `pytorch_lightning.LightningModule` 类定义模型,设置训练配置,并使用 PyTorch Lightning 框架的 `Trainer` 启动训练过程。
##### 2.1 导入必要的包
在 `Training_the_model.ipynb` 笔记本的第一个单元格中,导入必要的包:
```python
import os
import json
import pickle
import nltk
from PIL import Image
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import pytorch_lightning as pl
from model import HybridModel
from vocabulary import Vocabulary
```
##### 2.2 定义数据集类
定义 `CocoDataset` 类,它继承自 `torch.utils.data.Dataset` 类。这是一个映射式数据集,需要定义 `__getitem__()` 和 `__len__()` 方法:
```python
class CocoDataset(data.Dataset):
def __init__(self, data_path, json_path, vocabulary, transform=None):
self.image_dir = data_path
self.vocabulary = vocabulary
self.transform = transform
with open(json_path) as json_file:
self.coco = json.load(json_file)
self.image_id_file_name = dict()
for image in self.coco['images']:
self.image_id_file_name[image['id']] = image['file_name']
def __getitem__(self, idx):
annotation = self.coco['annotations'][idx]
caption = annotation['caption']
tkns = nltk.tokenize.word_tokenize(str(caption).lower())
caption = []
caption.append(self.vocabulary('<start>'))
caption.extend([self.vocabulary(tkn) for tkn in tkns])
caption.append(self.vocabulary('<end>'))
image_id = annotation['image_id']
image_file = self.image_id_file_name[image_id]
image_path = os.path.join(self.image_dir, image_file)
image = Image.open(image_path).convert('RGB')
if self.transform is not None:
image = self.transform(image)
return image, torch.Tensor(caption)
def __len__(self):
return len(self.coco['annotations'])
```
需要注意的是,COCO 数据集每张图像有五个字幕。由于模型处理每个图像 - 字幕对,因此数据集的长度等于字幕的总数,而不是图像的总数。
##### 2.3 定义整理函数
在 `Training_the_model.ipynb` 笔记本的下一个单元格中,定义 `coco_collate_fn()` 整理函数,用于对一批图像和对应的字幕进行填充:
```python
def coco_collate_fn(data_batch):
data_batch.sort(key=lambda d: len(d[1]), reverse=True)
imgs, caps = zip(*data_batch)
imgs = torch.stack(imgs, 0)
cap_lens = [len(cap) for cap in caps]
padded_caps = torch.zeros(len(caps), max(cap_lens)).long()
for i, cap in enumerate(caps):
end = cap_lens[i]
padded_caps[i, :end] = cap[:end]
return imgs, padded_caps, cap_lens
```
##### 2.4 定义数据加载器函数
在 `Training_the_model.ipynb` 笔记本的下一个单元格中,定义 `get_loader()` 函数,用于创建数据加载器:
```python
def get_loader(data_path, json_path, vocabulary, transform, batch_size, shuffle, num_workers=0):
coco_ds = CocoDataset(data_path=data_path,
json_path=json_path,
vocabulary=vocabulary,
transform=transform)
coco_dl = data.DataLoader(dataset=coco_ds,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
```
0
0
复制全文
相关推荐










