Background overlay
1108 字
6 分钟
05-Pytorch 公共数据集合和DataLoader
2026-04-24
更新中...

一、torchvison 的标准数据集#

官方地址

import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
# print(test_set[0])
# print(test_set.classes)
#
# img, target = test_set[0]
# print(img)
# print(target)
# print(test_set.classes[target])
# img.show()
#
# print(test_set[0])
writer = SummaryWriter("p10")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()

二、DataLoader#

解决如何高效、规范地将数据从数据集取到dl中。PyTorch 提供了 torch.utils.data.DataLoader 这一强大的工具来解决数据的批量加载、打乱与多进程读取问题。本文将深入剖析 DataLoader 的底层逻辑、核心参数以及结合 TensorBoard 的综合应用场景。

1. 核心概念:Dataset 与 DataLoader 的协同逻辑#

在 PyTorch 的数据处理流中,DatasetDataLoader 是两个密不可分但分工不同的核心组件。

核心比喻

如果将数据集 (Dataset) 比作一副已经洗好或按顺序排列的扑克牌,那么 DataLoader 就是那个发牌的荷官

Dataset 负责告诉你“总共有多少张牌”以及“第 ii 张牌是什么”;而 DataLoader 则决定了“每次发几张牌(Batch)”、“发牌前要不要再洗一次牌(Shuffle)”以及“用几只手同时发牌(多进程)”。

通过 DataLoader,我们可以将零散的单条数据打包成具有统一维度(通常增加了一个 Batch 维度)的张量集合,从而利用 GPU 的并行计算能力加速训练。

2. 核心参数#

实例化 DataLoader 时,有几个至关重要的参数决定了数据加载的行为模式:

  • dataset (Dataset):数据源。必须是实现了 __len____getitem__ 方法的 Dataset 对象。

  • batch_size (int):批大小。决定了“每次发牌的数量”。例如 batch_size=64 意味着每次迭代会同时返回 64 张图片及其对应的标签。返回的图像张量形状将从 [C, H, W] 变为 [Batch, C, H, W]

  • shuffle (bool):是否打乱数据。

    • 设置为 True 时,每一个 Epoch 开始前,数据都会被重新打乱。这对于训练集(Train Set)至关重要,能有效防止模型记住数据的输入顺序,提升泛化能力。

    • 对于测试集(Test Set)通常设为 False

  • num_workers (int):加载数据时的子进程数量。

    • num_workers=0 表示只在主进程中加载数据(同步执行)。

    • 大于 0 时表示开启多进程异步加载,可以显著提高数据读取速度,缓解 GPU 等待 CPU 处理数据的问题。

  • drop_last (bool):是否丢弃最后一个不完整的 Batch。

    • 假设数据集有 100 张图片,batch_size 设为 30。前 3 次会每次取 30 张,最后剩下 10 张。

    • 若设为 True,则这最后的 10 张图片会被直接舍弃,不参与当前 Epoch 的计算;若设为 False(默认),则最后一次迭代会返回包含 10 张图片的 Batch。

3. 综合应用:结合 TensorBoard 的数据流可视化#

下面通过一段完整的标准代码,展示如何加载 CIFAR10 数据集,并通过 DataLoader 进行批量读取,最后利用 TensorBoard 可视化不同 Epoch 下数据加载的效果。

import torchvision
# 准备的测试数据集
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)
writer = SummaryWriter("dataloader")
for epoch in range(2):
step = 0
for data in test_loader:
imgs, targets = data
# print(imgs.shape)
# print(targets)
writer.add_images("Epoch: {}".format(epoch), imgs, step)
step = step + 1
writer.close()

4. 注意事项#

  1. drop_last

    • 网络架构包含了某些对 Batch Size 严格敏感的层,最后一个非标准的 Batch 会导致形状不匹配从而报错。

    • 在训练阶段,当数据集数量无法整除 batch size 时,建议将其设为 drop_last=True,保证所有 Batch 的张量维度绝对统一;而在测试阶段,为了不漏掉任何一条测试数据,必须保持默认的 drop_last=False

  2. add_imageadd_images

    • Dataset 中直接取出的单条图像张量形状为 [C, H, W],应使用 writer.add_image();而从 DataLoader 中取出的批量图像张量形状为 [B, C, H, W]必须使用 writer.add_images(),TensorBoard 会自动将其拼接成网格形式展示。
  3. shuffle=True

    • shuffle=True 并不是在创建 DataLoader 时洗一次牌就结束了。它的底层机制是:在进入每一个新的 for epoch in range 循环时,DataLoader 都会重新洗牌,达到上图中每个epoch相同step但图片都不相同的效果。
05-Pytorch 公共数据集合和DataLoader
https://icemeow.top/blog/posts/graduate/pytorch-4/
作者
ICEMeow
发布于
2026-04-24
许可协议
CC BY-NC-SA 4.0