Background overlay
1063 字
5 分钟
04-Pytorch 数据预处理Transform
2026-04-22
更新中...

PyTorch 数据预处理利器:Transforms 核心用法与避坑指南#

在 PyTorch 深度学习项目中,数据预处理是至关重要的一环。torchvision.transforms 为我们提供了一个强大的“图像处理工具箱”。本文将系统性地梳理 Transforms 的核心概念、常用函数以及在实际开发中的避坑技巧。

一、 Transforms 核心概念#

你可以把 transforms 想象成一个装满各种图像处理工具的工具箱。使用它的基本流程分为三步:

  1. 挑选工具:从工具箱中拿出一个工具(即实例化一个 Transform 类)。

  2. 放入原料:传入需要处理的图像数据。

  3. 获取成品:得到处理后的输出结果。

关键底层逻辑:Python 的 __call__ 魔术方法#

很多初学者会疑惑,为什么实例化一个类之后,可以直接像调用函数一样向里面传入参数?例如:

tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img) # 为什么可以直接调用?

这是因为 Transforms 工具箱中的类都实现了 Python 的 __call__ 魔术方法。任何实现了该方法的对象,都可以被当做函数直接调用,这使得代码更加简洁优雅。

二、 数据类型的流转#

在使用 Transforms 时,最核心的法则是时刻关注数据的输入(Input)和输出(Output)类型。在深度学习图像处理中,主要涉及到以下三种数据类型:

  • PIL Image:通过 PIL.Image.open() 读取,常规的图像对象。

  • numpy.ndarray:通过 cv2.imread() 读取,OpenCV 默认格式。

  • Tensor:PyTorch 的张量格式。它不仅包含图像的矩阵数据,还包含了神经网络训练所需的关键参数(如梯度属性 grad、设备属性 device 等)。

三、 常用 Transforms 函数详解#

1. 基础转换:ToTensor#

  • 作用:将 PIL Imagenumpy.ndarray 格式的图片,转换为 PyTorch 训练必备的 Tensor 格式。同时会将像素值从 “ 缩放到 [0.0, 1.0]
from torchvision import transforms
from PIL import Image
# 读取图片 (PIL Image)
img_path = "dataset/train/ants_image/0013035.jpg"
img = Image.open(img_path)
# 实例化并调用 ToTensor
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)

162

2. 数据标准化:Normalize#

  • 作用:用指定的均值(Mean)和标准差(Std)对 Tensor 图像进行归一化处理。这有助于加快神经网络的收敛速度。

  • 计算公式output[channel] = (input[channel] - mean[channel]) / std[channel]

  • 注意:此函数的输入和输出都必须是 Tensor

# 设置 RGB 3个通道的均值和标准差 (这里以 0.5 为例)
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# 传入刚才由 ToTensor 转换得到的 tensor_img
img_norm = trans_norm(tensor_img)

203

3. 尺寸调整:Resize#

  • 作用:将图片缩放为指定的尺寸大小。

  • 注意:此函数的输入通常是 PIL Image,输出也是 PIL Image

# 用法1:传入元组 (H, W),强制缩放为指定的高和宽
trans_resize = transforms.Resize((512, 512))
img_resize = trans_resize(img)
# 用法2:传入一个整数,会将图片的短边缩放至该数值,长边保持原比例缩放
trans_resize_2 = transforms.Resize(512)

4. 数据增强:RandomCrop#

  • 作用:在图片的随机位置裁剪出指定大小的图片,常用于数据增强,防止模型过拟合。
# 随机裁剪出 512x512 大小的图片
trans_random = transforms.RandomCrop(512)
# 也可以传入元组指定高宽
# trans_random = transforms.RandomCrop((500, 1000))

四、 打造流水线:Compose#

在实际项目中,我们往往需要对一张图片进行多步操作(例如:先缩放 -> 再裁剪 -> 最后转为 Tensor)。transforms.Compose 可以将这些操作像流水线一样打包依次执行。

核心原则:流水线中,前一个操作的输出数据类型,必须严格等于后一个操作所需的输入数据类型。

# 1. 准备所需的操作
trans_resize = transforms.Resize(512)
trans_random = transforms.RandomCrop(512)
trans_totensor = transforms.ToTensor()
# 2. 将操作打包成流水线(注意参数是一个列表)
pipeline = transforms.Compose([
trans_resize, # PIL -> PIL
trans_random, # PIL -> PIL
trans_totensor # PIL -> Tensor
])
# 3. 执行流水线
final_tensor_img = pipeline(img)

五、 可视化辅助验证#

为了验证 Transforms 的处理效果,我们可以结合 Tensorboard 进行直观的图像查看:

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
# 添加 ToTensor 后的图片
writer.add_image("1_ToTensor", tensor_img, 1)
# 添加 Normalize 后的图片 (颜色通常会发生明显变化)
writer.add_image("2_Normalize", img_norm, 1)
writer.close()
04-Pytorch 数据预处理Transform
https://icemeow.top/blog/posts/graduate/pytorch-3/
作者
ICEMeow
发布于
2026-04-22
许可协议
CC BY-NC-SA 4.0