Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • Python

  • MLTutorials

  • 卷积神经网络

  • 循环神经网络

  • Transformer

  • VisionTransformer

  • 扩散模型

  • 计算机视觉

  • PTM

  • MoE

  • LoRAMoE

  • LongTailed

  • 多模态

  • 知识蒸馏

  • PEFT

  • 对比学习

  • 小样本学习

  • 迁移学习

  • 零样本学习

  • 集成学习

  • Mamba

  • PyTorch

    • PyTorch概述

    • Tensors

    • 数据处理

      • DataLoader 与 DataSet
        • DataLoader 与 DataSet
        • Dataloader
          • 参数
        • Dataset
          • 参数说明
          • 需要实现的方法
          • 示例
          • 其他常见方法(可选实现)
          • 注意事项
        • PyTorch 数据读取流程图
          • 通过 DataLoader 获取数据
        • torchvision
        • pytorch 自带的数据集
          • torchversion.datasets
          • MNIST 数据集的介绍
      • torchvision.transforms
      • torch.utils.data
    • 模型

    • 训练

    • 并行计算

    • 可视化

    • 实战

    • timm

    • Pytorch Lightning

    • 数据增强

    • 面经与bug解决

    • 常用代码片段

    • Reference
  • CL

  • CIL

  • 小样本类增量学习FSCIL

  • UCIL

  • 多模态增量学习MMCL

  • LTCIL

  • DIL

  • 论文阅读与写作

  • 分布外检测

  • GPU

  • 深度学习调参指南

  • AINotes
  • PyTorch
  • 数据处理
Geeks_Z
2022-10-30
目录

DataLoader 与 DataSet

DataLoader 与 DataSet

PyTorch 数据读入是通过 Dataset+DataLoader 的方式完成的,Dataset 定义好数据的格式和数据变换形式,DataLoader 用 iterative 的方式不断读入批次数据。

torch.utils.data (opens new window)

PyTorch 的五大模块:数据、模型、损失函数、优化器和迭代训练。

数据模块又可以细分为 4 个部分:

  • 数据收集:样本和标签。
  • 数据划分:训练集、验证集和测试集
  • 数据读取:对应于 PyTorch 的 DataLoader。其中 DataLoader 包括 Sampler 和 DataSet。Sampler 的功能是生成索引, DataSet 是根据生成的索引读取样本以及标签。
  • 数据预处理:对应于 PyTorch 的 transforms

image-20220912180853467

Dataloader

数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
1

https://gitee.com/geeks_z/upload_images/raw/master/202112212156941.png

参数

  • dataset (Dataset) – 加载数据的数据集,Dataset 类,决定数据从哪里读取以及如何读取。
  • batch_size (int, optional) – 每个 batch 加载多少个样本(默认: 1)。
  • shuffle (bool, optional) – 设置为True时会在每个 epoch 重新打乱数据(默认: False).
  • sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
  • num_workers (int, optional) – 用多少个子进程加载数据。0 表示数据将在主进程中加载(默认: 0)
  • collate_fn (callable, optional) –
  • pin_memory (bool, optional) –
  • drop_last (bool, optional) – 如果数据集大小不能被 batch size 整除,则设置为 True 后可删除最后一个不完整的 batch。如果设为 False 并且数据集的大小不能被 batch size 整除,则最后一个 batch 将更小。(默认: False)

DataLoader的使用方法示例:

from torch.utils.data import DataLoader

dataset = CifarDataset()
data_loader = DataLoader(dataset=dataset,batch_size=10,shuffle=True,num_workers=2)

#遍历,获取其中的每个batch的结果
for index, (label, context) in enumerate(data_loader):
    print(index,label,context)
    print("*"*100)
1
2
3
4
5
6
7
8
9

数据迭代器的返回结果如下:

555 ('spam', 'ham', 'spam', 'ham', 'ham', 'ham', 'ham', 'spam', 'ham', 'ham') ('URGENT! We are trying to contact U. Todays draw shows that you have won a £800 prize GUARANTEED. Call 09050003091 from....", 'swhrt how u dey,hope ur ok, tot about u 2day.love n miss.take care.')
***********************************************************************************
556 ('ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'spam') ('He telling not to tell any one. If so treat for me hi hi hi', 'Did u got that persons story', "Don kn....1000 cash prize or a prize worth £5000')
1
2
3

注意:

  1. len(dataset) = 数据集的样本数
  2. len(data_loader) = math.ceil(样本数/batch_size) 即向上取整

Dataset

torch.utils.data.Dataset 是 PyTorch 中的一个抽象类,用于表示一个数据集。当你想要创建自己的数据集时,你需要继承这个类并实现至少两个方法:__len__ 和 __getitem__。

参数说明

torch.utils.data.Dataset 本身并没有直接的参数,因为它是一个抽象基类,需要子类实现具体的方法。

需要实现的方法

  1. __len__(self):

    • 返回数据集的大小(即数据项的总数)。
    • 当你使用 len(dataset) 时,这个方法会被调用。
  2. __getitem__(self, index):

    • 根据提供的索引返回单个数据项。
    • 当你使用 dataset[index] 时,这个方法会被调用。
    • 返回的数据项通常是一个元组,包含输入数据和标签(如果有的话)。

示例

from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# 使用示例
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)

# 获取数据集的大小
print(len(dataset))  # 输出: 5

# 获取索引为2的数据项
print(dataset[2])  # 输出: 3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

其他常见方法(可选实现)

虽然 __len__ 和 __getitem__ 是必需的,但你可能还想实现其他方法以提供更多的功能,例如:

  • __init__: 用于向类中传入外部参数,同时定义样本集
  • __add__(self, other): 实现数据集的加法操作,使得你可以合并两个数据集。
  • transform(self, fn): 定义一个转换方法,该方法接受一个函数 fn 并返回一个新的数据集,其中每个数据项都经过 fn 的处理。

注意事项

  • 当创建自定义数据集时,确保你的 __getitem__ 方法返回的数据类型与你的模型期望的输入类型相匹配。
  • 如果你想要进行批量处理,可以考虑使用 torch.utils.data.DataLoader,它可以与你的 Dataset 子类一起使用。

PyTorch 数据读取流程图

Untitled

首先在 for 循环中遍历DataLoader,然后根据是否采用多进程,决定使用单进程或者多进程的DataLoaderIter。在DataLoaderIter里调用Sampler生成Index的 list,再调DatasetFetcher 根据index获取数据。在DatasetFetcher里会调用Dataset的__getitem__()方法获取真正的数据。这里获取的数据是一个 list,其中每个元素是 (img, label) 的元组,再使用 collate_fn()函数整理成一个 list,里面包含两个元素,分别是 img 和 label 的tenser。

通过 DataLoader 获取数据

for i, (_, inputs, targets) in enumerate(sample_loader):
1

其中 (_, inputs, targets)与DummyDataset中的__getitem__返回值对应

class DummyDataset(Dataset):
    def __init__(self, images, labels, trsf, use_path=False):
        ……

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_path:
            image = self.trsf(pil_loader(self.images[idx]))
        else:
            image = self.trsf(Image.fromarray(self.images[idx]))
        label = self.labels[idx]

        return idx, image, label
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

torchvision

torchvision 是独立于 pytorch 的关于图像操作的一些方便工具库。

torchvision 的详细介绍在:https://pypi.org/project/torchvision/

torchvision 主要包括以下几个包:

  • vision.datasets (opens new window) : 几个常用视觉数据集,可以下载和加载,这里主要的高级用法就是可以看源码如何自己写自己的 Dataset 的子类
  • vision.models (opens new window) : 流行的模型,例如 AlexNet, VGG, ResNet 和 Densenet 以及 与训练好的参数。
  • vision.transforms (opens new window) : 常用的图像操作,例如:随机切割,旋转,数据类型转换,图像到 tensor ,numpy 数组到 tensor , tensor 到 图像等。
  • vision.utils (opens new window) : 用于把形似 (3 x H x W) 的张量保存到硬盘中,给一个 mini-batch 的图像可以产生一个图像格网。

pytorch 自带的数据集

pytorch 中自带的数据集由两个上层 api 提供,分别是torchvision和torchtext

其中:

  1. torchvision提供了对图片数据处理相关的 api 和数据
    • 数据位置:torchvision.datasets,例如:torchvision.datasets.MNIST(手写数字图片数据)
  2. torchtext提供了对文本数据处理相关的 API 和数据
    • 数据位置:torchtext.datasets,例如:torchtext.datasets.IMDB(电影评论文本数据)

下面我们以 Mnist 手写数字为例,来看看 pytorch 如何加载其中自带的数据集

使用方法和之前一样:

  1. 准备好 Dataset 实例
  2. 把 dataset 交给 dataloder 打乱顺序,组成 batch

torchversion.datasets

torchversoin.datasets中的数据集类(比如torchvision.datasets.MNIST),都是继承自Dataset

意味着:直接对torchvision.datasets.MNIST进行实例化就可以得到Dataset的实例

但是 MNIST API 中的参数需要注意一下:

torchvision.datasets.MNIST(root='/files/', train=True, download=True, transform=)

  1. root参数表示数据存放的位置
  2. train:bool 类型,表示是使用训练集的数据还是测试集的数据
  3. download:bool 类型,表示是否需要下载数据到 root 目录
  4. transform:实现的对图片的处理函数

MNIST 数据集的介绍

数据集的原始地址:http://yann.lecun.com/exdb/mnist/

MNIST 是由Yann LeCun等人提供的免费的图像识别的数据集,其中包括 60000 个训练样本和 10000 个测试样本,其中图拍了的尺寸已经进行的标准化的处理,都是黑白的图像,大小为28X28

执行代码,下载数据,观察数据类型:

import torchvision
dataset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=None)
print(dataset[0])
1
2
3

下载的数据如下:

Untitled

代码输出结果如下:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!
(<PIL.Image.Image image mode=L size=28x28 at 0x18D303B9C18>, tensor(5))
1
2
3
4
5
6
7

可以其中数据集返回了两条数据,可以猜测为图片的数据和目标值

返回值的第 0 个为 Image 类型,可以调用 show() 方法打开,发现为手写数字 5

import torchvision
dataset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=None)
print(dataset[0])
img = dataset[0][0]
img.show() #打开图片
1
2
3
4
5

图片如下:

由上可知:返回值为(图片,目标值),这个结果也可以通过观察源码得到。

#PyTorch
上次更新: 2025/06/25, 11:25:50
tensor可视化为图片
torchvision.transforms

← tensor可视化为图片 torchvision.transforms→

最近更新
01
帮助信息查看
06-08
02
常用命令
06-08
03
学习资源
06-07
更多文章>
Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
京公网安备 11010802040735号 | 京ICP备2022029989号-1
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式