一、torchvison 的标准数据集
import torchvisionfrom torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
# print(test_set[0])# print(test_set.classes)## img, target = test_set[0]# print(img)# print(target)# print(test_set.classes[target])# img.show()## print(test_set[0])
writer = SummaryWriter("p10")for i in range(10): img, target = test_set[i] writer.add_image("test_set", img, i)
writer.close()
二、DataLoader
解决如何高效、规范地将数据从数据集取到dl中。PyTorch 提供了 torch.utils.data.DataLoader 这一强大的工具来解决数据的批量加载、打乱与多进程读取问题。本文将深入剖析 DataLoader 的底层逻辑、核心参数以及结合 TensorBoard 的综合应用场景。
1. 核心概念:Dataset 与 DataLoader 的协同逻辑
在 PyTorch 的数据处理流中,Dataset 和 DataLoader 是两个密不可分但分工不同的核心组件。
核心比喻:
如果将数据集 (
Dataset) 比作一副已经洗好或按顺序排列的扑克牌,那么 DataLoader 就是那个发牌的荷官。
Dataset负责告诉你“总共有多少张牌”以及“第 张牌是什么”;而DataLoader则决定了“每次发几张牌(Batch)”、“发牌前要不要再洗一次牌(Shuffle)”以及“用几只手同时发牌(多进程)”。
通过 DataLoader,我们可以将零散的单条数据打包成具有统一维度(通常增加了一个 Batch 维度)的张量集合,从而利用 GPU 的并行计算能力加速训练。
2. 核心参数
实例化 DataLoader 时,有几个至关重要的参数决定了数据加载的行为模式:
-
dataset(Dataset):数据源。必须是实现了__len__和__getitem__方法的 Dataset 对象。 -
batch_size(int):批大小。决定了“每次发牌的数量”。例如batch_size=64意味着每次迭代会同时返回 64 张图片及其对应的标签。返回的图像张量形状将从[C, H, W]变为[Batch, C, H, W]。 -
shuffle(bool):是否打乱数据。-
设置为
True时,每一个 Epoch 开始前,数据都会被重新打乱。这对于训练集(Train Set)至关重要,能有效防止模型记住数据的输入顺序,提升泛化能力。 -
对于测试集(Test Set)通常设为
False。
-
-
num_workers(int):加载数据时的子进程数量。-
num_workers=0表示只在主进程中加载数据(同步执行)。 -
大于 0 时表示开启多进程异步加载,可以显著提高数据读取速度,缓解 GPU 等待 CPU 处理数据的问题。
-
-
drop_last(bool):是否丢弃最后一个不完整的 Batch。-
假设数据集有 100 张图片,
batch_size设为 30。前 3 次会每次取 30 张,最后剩下 10 张。 -
若设为
True,则这最后的 10 张图片会被直接舍弃,不参与当前 Epoch 的计算;若设为False(默认),则最后一次迭代会返回包含 10 张图片的 Batch。
-
3. 综合应用:结合 TensorBoard 的数据流可视化
下面通过一段完整的标准代码,展示如何加载 CIFAR10 数据集,并通过 DataLoader 进行批量读取,最后利用 TensorBoard 可视化不同 Epoch 下数据加载的效果。
import torchvision
# 准备的测试数据集from torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# 测试数据集中第一张图片及targetimg, target = test_data[0]print(img.shape)print(target)
writer = SummaryWriter("dataloader")for epoch in range(2): step = 0 for data in test_loader: imgs, targets = data # print(imgs.shape) # print(targets) writer.add_images("Epoch: {}".format(epoch), imgs, step) step = step + 1
writer.close()
4. 注意事项
-
drop_last-
网络架构包含了某些对 Batch Size 严格敏感的层,最后一个非标准的 Batch 会导致形状不匹配从而报错。
-
在训练阶段,当数据集数量无法整除 batch size 时,建议将其设为
drop_last=True,保证所有 Batch 的张量维度绝对统一;而在测试阶段,为了不漏掉任何一条测试数据,必须保持默认的drop_last=False。
-
-
add_image与add_images- 从
Dataset中直接取出的单条图像张量形状为[C, H, W],应使用writer.add_image();而从DataLoader中取出的批量图像张量形状为[B, C, H, W],必须使用writer.add_images(),TensorBoard 会自动将其拼接成网格形式展示。
- 从
-
shuffle=Trueshuffle=True并不是在创建DataLoader时洗一次牌就结束了。它的底层机制是:在进入每一个新的for epoch in range循环时,DataLoader 都会重新洗牌,达到上图中每个epoch相同step但图片都不相同的效果。
正在加载评论...
链上评论区