小白记录:transforms.Compose
是 PyTorch 中 torchvision.transforms
模块的核心工具,用于将多个数据预处理/增强操作组合成一个可统一执行的变换管道。它的主要作用和特点如下:
核心作用
-
串联多个变换操作
将多个独立的图像预处理/增强步骤(如缩放、裁剪、归一化等)按顺序组合成一个整体变换流程。 -
确保执行顺序
严格按照Compose
中列出的顺序依次执行每个变换(例如:先调整大小 → 再裁剪 → 最后归一化)。
为什么需要它?
在训练深度学习模型(尤其是图像任务)时,数据通常需要经过一系列预处理步骤,例如:
-
调整图像尺寸(
Resize
) -
随机裁剪(
RandomCrop
) -
转换为张量(
ToTensor
) -
标准化像素值(
Normalize
)
Compose
将这些步骤封装为一个整体,简化代码并确保逻辑清晰。
使用示例
python
复制
下载
from torchvision import transforms # 定义组合变换 transform = transforms.Compose([ transforms.Resize(256), # 步骤1: 调整图像大小为256x256 transforms.RandomCrop(224), # 步骤2: 随机裁剪224x224区域 transforms.RandomHorizontalFlip(), # 步骤3: 随机水平翻转(数据增强) transforms.ToTensor(), # 步骤4: 将图像转为Tensor格式(通道优先) transforms.Normalize( # 步骤5: 标准化像素值 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ]) # 应用变换 transformed_image = transform(original_image)
关键特性
-
顺序敏感
操作顺序直接影响结果。例如:-
ToTensor()
必须在Normalize()
之前(因为归一化需要 Tensor 输入)。 -
数据增强(如翻转)应在转换为 Tensor 之前进行。
-
-
支持自定义变换
可与自定义函数/类结合使用:python
复制
下载
transforms.Compose([ transforms.Lambda(lambda x: x.rotate(90)), # 自定义旋转 transforms.ToTensor() ])
-
兼容数据加载器
可直接传入torch.utils.data.DataLoader
:python
复制
下载
dataset = torchvision.datasets.ImageFolder( root="data_dir", transform=transform # 使用Compose定义的变换 )
常见变换操作
变换类型 | 代表函数 |
---|---|
尺寸调整 | Resize() , CenterCrop() |
数据增强 | RandomHorizontalFlip() , ColorJitter() |
类型转换 | ToTensor() , ToPILImage() |
数值归一化 | Normalize() |
自定义操作 | Lambda() |
总结
transforms.Compose
的核心价值在于:
-
代码模块化:将零散的预处理步骤整合为一个逻辑单元。
-
确保可重复性:固定的变换顺序保证每次处理一致。
-
无缝对接PyTorch流程:简化数据加载到模型训练的流程。
它是构建高效数据预处理管道的标准工具,尤其在计算机视觉任务中不可或缺。