PyTorch 数据预处理利器:Transforms 核心用法与避坑指南
在 PyTorch 深度学习项目中,数据预处理是至关重要的一环。torchvision.transforms 为我们提供了一个强大的“图像处理工具箱”。本文将系统性地梳理 Transforms 的核心概念、常用函数以及在实际开发中的避坑技巧。
一、 Transforms 核心概念
你可以把 transforms 想象成一个装满各种图像处理工具的工具箱。使用它的基本流程分为三步:
-
挑选工具:从工具箱中拿出一个工具(即实例化一个 Transform 类)。
-
放入原料:传入需要处理的图像数据。
-
获取成品:得到处理后的输出结果。
关键底层逻辑:Python 的 __call__ 魔术方法
很多初学者会疑惑,为什么实例化一个类之后,可以直接像调用函数一样向里面传入参数?例如:
tensor_trans = transforms.ToTensor()tensor_img = tensor_trans(img) # 为什么可以直接调用?这是因为 Transforms 工具箱中的类都实现了 Python 的 __call__ 魔术方法。任何实现了该方法的对象,都可以被当做函数直接调用,这使得代码更加简洁优雅。
二、 数据类型的流转
在使用 Transforms 时,最核心的法则是时刻关注数据的输入(Input)和输出(Output)类型。在深度学习图像处理中,主要涉及到以下三种数据类型:
-
PIL Image:通过
PIL.Image.open()读取,常规的图像对象。 -
numpy.ndarray:通过
cv2.imread()读取,OpenCV 默认格式。 -
Tensor:PyTorch 的张量格式。它不仅包含图像的矩阵数据,还包含了神经网络训练所需的关键参数(如梯度属性
grad、设备属性device等)。
三、 常用 Transforms 函数详解
1. 基础转换:ToTensor
- 作用:将
PIL Image或numpy.ndarray格式的图片,转换为 PyTorch 训练必备的Tensor格式。同时会将像素值从 “ 缩放到[0.0, 1.0]。
from torchvision import transformsfrom PIL import Image
# 读取图片 (PIL Image)img_path = "dataset/train/ants_image/0013035.jpg"img = Image.open(img_path)
# 实例化并调用 ToTensortensor_trans = transforms.ToTensor()tensor_img = tensor_trans(img)
2. 数据标准化:Normalize
-
作用:用指定的均值(Mean)和标准差(Std)对 Tensor 图像进行归一化处理。这有助于加快神经网络的收敛速度。
-
计算公式:
output[channel] = (input[channel] - mean[channel]) / std[channel] -
注意:此函数的输入和输出都必须是 Tensor。
# 设置 RGB 3个通道的均值和标准差 (这里以 0.5 为例)trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# 传入刚才由 ToTensor 转换得到的 tensor_imgimg_norm = trans_norm(tensor_img)
3. 尺寸调整:Resize
-
作用:将图片缩放为指定的尺寸大小。
-
注意:此函数的输入通常是 PIL Image,输出也是 PIL Image。
# 用法1:传入元组 (H, W),强制缩放为指定的高和宽trans_resize = transforms.Resize((512, 512))img_resize = trans_resize(img)
# 用法2:传入一个整数,会将图片的短边缩放至该数值,长边保持原比例缩放trans_resize_2 = transforms.Resize(512)4. 数据增强:RandomCrop
- 作用:在图片的随机位置裁剪出指定大小的图片,常用于数据增强,防止模型过拟合。
# 随机裁剪出 512x512 大小的图片trans_random = transforms.RandomCrop(512)
# 也可以传入元组指定高宽# trans_random = transforms.RandomCrop((500, 1000))四、 打造流水线:Compose
在实际项目中,我们往往需要对一张图片进行多步操作(例如:先缩放 -> 再裁剪 -> 最后转为 Tensor)。transforms.Compose 可以将这些操作像流水线一样打包依次执行。
核心原则:流水线中,前一个操作的输出数据类型,必须严格等于后一个操作所需的输入数据类型。
# 1. 准备所需的操作trans_resize = transforms.Resize(512)trans_random = transforms.RandomCrop(512)trans_totensor = transforms.ToTensor()
# 2. 将操作打包成流水线(注意参数是一个列表)pipeline = transforms.Compose([ trans_resize, # PIL -> PIL trans_random, # PIL -> PIL trans_totensor # PIL -> Tensor])
# 3. 执行流水线final_tensor_img = pipeline(img)五、 可视化辅助验证
为了验证 Transforms 的处理效果,我们可以结合 Tensorboard 进行直观的图像查看:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
# 添加 ToTensor 后的图片writer.add_image("1_ToTensor", tensor_img, 1)
# 添加 Normalize 后的图片 (颜色通常会发生明显变化)writer.add_image("2_Normalize", img_norm, 1)
writer.close()正在加载评论...
链上评论区